Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
nni
Commits
802650ff
"tests/vscode:/vscode.git/clone" did not exist on "73acebb8cfbd1d2954cabe1af4185f9994e61917"
Unverified
Commit
802650ff
authored
Aug 12, 2022
by
Yuge Zhang
Committed by
GitHub
Aug 12, 2022
Browse files
Miscellaneous fixes of NAS (v2.9) (#5051)
parent
cd98c48f
Changes
8
Hide whitespace changes
Inline
Side-by-side
Showing
8 changed files
with
41 additions
and
22 deletions
+41
-22
examples/nas/multi-trial/nasbench101/network.py
examples/nas/multi-trial/nasbench101/network.py
+1
-1
examples/nas/multi-trial/nasbench201/network.py
examples/nas/multi-trial/nasbench201/network.py
+1
-1
nni/nas/benchmarks/utils.py
nni/nas/benchmarks/utils.py
+6
-1
nni/nas/evaluator/pytorch/cgo/evaluator.py
nni/nas/evaluator/pytorch/cgo/evaluator.py
+2
-2
nni/nas/evaluator/pytorch/lightning.py
nni/nas/evaluator/pytorch/lightning.py
+24
-13
nni/nas/execution/common/graph.py
nni/nas/execution/common/graph.py
+1
-1
nni/nas/nn/pytorch/layers.py
nni/nas/nn/pytorch/layers.py
+3
-0
test/algo/nas/test_cgo_engine.py
test/algo/nas/test_cgo_engine.py
+3
-3
No files found.
examples/nas/multi-trial/nasbench101/network.py
View file @
802650ff
...
@@ -114,7 +114,7 @@ class NasBench101TrainingModule(pl.LightningModule):
...
@@ -114,7 +114,7 @@ class NasBench101TrainingModule(pl.LightningModule):
momentum
=
0.9
,
alpha
=
0.9
,
eps
=
1.0
)
momentum
=
0.9
,
alpha
=
0.9
,
eps
=
1.0
)
return
{
return
{
'optimizer'
:
optimizer
,
'optimizer'
:
optimizer
,
'scheduler'
:
CosineAnnealingLR
(
optimizer
,
self
.
hparams
.
max_epochs
)
'
lr_
scheduler'
:
CosineAnnealingLR
(
optimizer
,
self
.
hparams
.
max_epochs
)
}
}
def
on_validation_epoch_end
(
self
):
def
on_validation_epoch_end
(
self
):
...
...
examples/nas/multi-trial/nasbench201/network.py
View file @
802650ff
...
@@ -103,7 +103,7 @@ class NasBench201TrainingModule(pl.LightningModule):
...
@@ -103,7 +103,7 @@ class NasBench201TrainingModule(pl.LightningModule):
momentum
=
0.9
,
alpha
=
0.9
,
eps
=
1.0
)
momentum
=
0.9
,
alpha
=
0.9
,
eps
=
1.0
)
return
{
return
{
'optimizer'
:
optimizer
,
'optimizer'
:
optimizer
,
'scheduler'
:
CosineAnnealingLR
(
optimizer
,
self
.
hparams
.
max_epochs
)
'
lr_
scheduler'
:
CosineAnnealingLR
(
optimizer
,
self
.
hparams
.
max_epochs
)
}
}
def
on_validation_epoch_end
(
self
):
def
on_validation_epoch_end
(
self
):
...
...
nni/nas/benchmarks/utils.py
View file @
802650ff
...
@@ -31,7 +31,12 @@ def load_benchmark(benchmark: str) -> SqliteExtDatabase:
...
@@ -31,7 +31,12 @@ def load_benchmark(benchmark: str) -> SqliteExtDatabase:
return
_loaded_benchmarks
[
benchmark
]
return
_loaded_benchmarks
[
benchmark
]
url
=
DB_URLS
[
benchmark
]
url
=
DB_URLS
[
benchmark
]
local_path
=
os
.
path
.
join
(
DATABASE_DIR
,
os
.
path
.
basename
(
url
))
local_path
=
os
.
path
.
join
(
DATABASE_DIR
,
os
.
path
.
basename
(
url
))
load_or_download_file
(
local_path
,
url
)
try
:
load_or_download_file
(
local_path
,
url
)
except
FileNotFoundError
:
raise
FileNotFoundError
(
f
'Please use `nni.nas.benchmarks.download_benchmark("
{
benchmark
}
")` to setup the benchmark first before using it.'
)
_loaded_benchmarks
[
benchmark
]
=
SqliteExtDatabase
(
local_path
,
autoconnect
=
True
)
_loaded_benchmarks
[
benchmark
]
=
SqliteExtDatabase
(
local_path
,
autoconnect
=
True
)
return
_loaded_benchmarks
[
benchmark
]
return
_loaded_benchmarks
[
benchmark
]
...
...
nni/nas/evaluator/pytorch/cgo/evaluator.py
View file @
802650ff
...
@@ -12,7 +12,7 @@ from torch.utils.data import DataLoader
...
@@ -12,7 +12,7 @@ from torch.utils.data import DataLoader
import
nni
import
nni
from
..lightning
import
LightningModule
,
_
AccuracyWithLogits
,
Lightning
from
..lightning
import
LightningModule
,
AccuracyWithLogits
,
Lightning
from
.trainer
import
Trainer
from
.trainer
import
Trainer
__all__
=
[
__all__
=
[
...
@@ -148,7 +148,7 @@ class _ClassificationModule(_MultiModelSupervisedLearningModule):
...
@@ -148,7 +148,7 @@ class _ClassificationModule(_MultiModelSupervisedLearningModule):
learning_rate
:
float
=
0.001
,
learning_rate
:
float
=
0.001
,
weight_decay
:
float
=
0.
,
weight_decay
:
float
=
0.
,
optimizer
:
optim
.
Optimizer
=
optim
.
Adam
):
optimizer
:
optim
.
Optimizer
=
optim
.
Adam
):
super
().
__init__
(
criterion
,
{
'acc'
:
_
AccuracyWithLogits
},
super
().
__init__
(
criterion
,
{
'acc'
:
AccuracyWithLogits
},
learning_rate
=
learning_rate
,
weight_decay
=
weight_decay
,
optimizer
=
optimizer
)
learning_rate
=
learning_rate
,
weight_decay
=
weight_decay
,
optimizer
=
optimizer
)
...
...
nni/nas/evaluator/pytorch/lightning.py
View file @
802650ff
...
@@ -27,7 +27,7 @@ from nni.typehint import Literal
...
@@ -27,7 +27,7 @@ from nni.typehint import Literal
__all__
=
[
__all__
=
[
'LightningModule'
,
'Trainer'
,
'DataLoader'
,
'Lightning'
,
'Classification'
,
'Regression'
,
'LightningModule'
,
'Trainer'
,
'DataLoader'
,
'Lightning'
,
'Classification'
,
'Regression'
,
'
_AccuracyWithLogits'
,
'_
SupervisedLearningModule'
,
'
_
ClassificationModule'
,
'
_
RegressionModule'
,
'SupervisedLearningModule'
,
'ClassificationModule'
,
'RegressionModule'
,
'AccuracyWithLogits'
,
# FIXME: hack to make it importable for tests
# FIXME: hack to make it importable for tests
]
]
...
@@ -102,12 +102,15 @@ class Lightning(Evaluator):
...
@@ -102,12 +102,15 @@ class Lightning(Evaluator):
Used in ``trainer.fit()``. Either a single PyTorch Dataloader or a list of them, specifying validation samples.
Used in ``trainer.fit()``. Either a single PyTorch Dataloader or a list of them, specifying validation samples.
If the ``lightning_module`` has a predefined val_dataloaders method this will be skipped.
If the ``lightning_module`` has a predefined val_dataloaders method this will be skipped.
It can be `any types of dataloader supported by Lightning <https://pytorch-lightning.readthedocs.io/en/stable/guides/data.html>`__.
It can be `any types of dataloader supported by Lightning <https://pytorch-lightning.readthedocs.io/en/stable/guides/data.html>`__.
fit_kwargs
Keyword arguments passed to ``trainer.fit()``.
"""
"""
def
__init__
(
self
,
lightning_module
:
LightningModule
,
trainer
:
Trainer
,
def
__init__
(
self
,
lightning_module
:
LightningModule
,
trainer
:
Trainer
,
train_dataloaders
:
Optional
[
Any
]
=
None
,
train_dataloaders
:
Optional
[
Any
]
=
None
,
val_dataloaders
:
Optional
[
Any
]
=
None
,
val_dataloaders
:
Optional
[
Any
]
=
None
,
train_dataloader
:
Optional
[
Any
]
=
None
):
train_dataloader
:
Optional
[
Any
]
=
None
,
fit_kwargs
:
Optional
[
Dict
[
str
,
Any
]]
=
None
):
assert
isinstance
(
lightning_module
,
LightningModule
),
f
'Lightning module must be an instance of
{
__name__
}
.LightningModule.'
assert
isinstance
(
lightning_module
,
LightningModule
),
f
'Lightning module must be an instance of
{
__name__
}
.LightningModule.'
if
train_dataloader
is
not
None
:
if
train_dataloader
is
not
None
:
warnings
.
warn
(
'`train_dataloader` is deprecated and replaced with `train_dataloaders`.'
,
DeprecationWarning
)
warnings
.
warn
(
'`train_dataloader` is deprecated and replaced with `train_dataloaders`.'
,
DeprecationWarning
)
...
@@ -117,7 +120,7 @@ class Lightning(Evaluator):
...
@@ -117,7 +120,7 @@ class Lightning(Evaluator):
else
:
else
:
# this is not isinstance(trainer, Trainer) because with a different trace call, it can be different
# this is not isinstance(trainer, Trainer) because with a different trace call, it can be different
assert
(
isinstance
(
trainer
,
pl
.
Trainer
)
and
is_traceable
(
trainer
))
or
isinstance
(
trainer
,
cgo_trainer
.
Trainer
),
\
assert
(
isinstance
(
trainer
,
pl
.
Trainer
)
and
is_traceable
(
trainer
))
or
isinstance
(
trainer
,
cgo_trainer
.
Trainer
),
\
f
'Trainer must be imported from
{
__name__
}
or nni.
nas
.evaluator.pytorch.cgo.trainer'
f
'Trainer must be imported from
{
__name__
}
or nni.
retiarii
.evaluator.pytorch.cgo.trainer'
if
not
_check_dataloader
(
train_dataloaders
):
if
not
_check_dataloader
(
train_dataloaders
):
warnings
.
warn
(
f
'Please try to wrap PyTorch DataLoader with nni.trace or '
warnings
.
warn
(
f
'Please try to wrap PyTorch DataLoader with nni.trace or '
f
'import DataLoader from
{
__name__
}
:
{
train_dataloaders
}
'
,
f
'import DataLoader from
{
__name__
}
:
{
train_dataloaders
}
'
,
...
@@ -130,6 +133,7 @@ class Lightning(Evaluator):
...
@@ -130,6 +133,7 @@ class Lightning(Evaluator):
self
.
trainer
=
trainer
self
.
trainer
=
trainer
self
.
train_dataloaders
=
train_dataloaders
self
.
train_dataloaders
=
train_dataloaders
self
.
val_dataloaders
=
val_dataloaders
self
.
val_dataloaders
=
val_dataloaders
self
.
fit_kwargs
=
fit_kwargs
or
{}
@
staticmethod
@
staticmethod
def
_load
(
ir
):
def
_load
(
ir
):
...
@@ -178,7 +182,7 @@ class Lightning(Evaluator):
...
@@ -178,7 +182,7 @@ class Lightning(Evaluator):
The model to fit.
The model to fit.
"""
"""
self
.
module
.
set_model
(
model
)
self
.
module
.
set_model
(
model
)
return
self
.
trainer
.
fit
(
self
.
module
,
self
.
train_dataloaders
,
self
.
val_dataloaders
)
return
self
.
trainer
.
fit
(
self
.
module
,
self
.
train_dataloaders
,
self
.
val_dataloaders
,
**
self
.
fit_kwargs
)
def
_check_dataloader
(
dataloader
):
def
_check_dataloader
(
dataloader
):
...
@@ -194,7 +198,7 @@ def _check_dataloader(dataloader):
...
@@ -194,7 +198,7 @@ def _check_dataloader(dataloader):
### The following are some commonly used Lightning modules ###
### The following are some commonly used Lightning modules ###
class
_
SupervisedLearningModule
(
LightningModule
):
class
SupervisedLearningModule
(
LightningModule
):
trainer
:
pl
.
Trainer
trainer
:
pl
.
Trainer
...
@@ -273,19 +277,19 @@ class _SupervisedLearningModule(LightningModule):
...
@@ -273,19 +277,19 @@ class _SupervisedLearningModule(LightningModule):
return
{
name
:
self
.
trainer
.
callback_metrics
[
'val_'
+
name
].
item
()
for
name
in
self
.
metrics
}
return
{
name
:
self
.
trainer
.
callback_metrics
[
'val_'
+
name
].
item
()
for
name
in
self
.
metrics
}
class
_
AccuracyWithLogits
(
torchmetrics
.
Accuracy
):
class
AccuracyWithLogits
(
torchmetrics
.
Accuracy
):
def
update
(
self
,
pred
,
target
):
def
update
(
self
,
pred
,
target
):
return
super
().
update
(
nn_functional
.
softmax
(
pred
,
dim
=-
1
),
target
)
return
super
().
update
(
nn_functional
.
softmax
(
pred
,
dim
=-
1
),
target
)
@
nni
.
trace
@
nni
.
trace
class
_
ClassificationModule
(
_
SupervisedLearningModule
):
class
ClassificationModule
(
SupervisedLearningModule
):
def
__init__
(
self
,
criterion
:
Type
[
nn
.
Module
]
=
nn
.
CrossEntropyLoss
,
def
__init__
(
self
,
criterion
:
Type
[
nn
.
Module
]
=
nn
.
CrossEntropyLoss
,
learning_rate
:
float
=
0.001
,
learning_rate
:
float
=
0.001
,
weight_decay
:
float
=
0.
,
weight_decay
:
float
=
0.
,
optimizer
:
Type
[
optim
.
Optimizer
]
=
optim
.
Adam
,
optimizer
:
Type
[
optim
.
Optimizer
]
=
optim
.
Adam
,
export_onnx
:
bool
=
True
):
export_onnx
:
bool
=
True
):
super
().
__init__
(
criterion
,
{
'acc'
:
_
AccuracyWithLogits
},
super
().
__init__
(
criterion
,
{
'acc'
:
AccuracyWithLogits
},
learning_rate
=
learning_rate
,
weight_decay
=
weight_decay
,
optimizer
=
optimizer
,
learning_rate
=
learning_rate
,
weight_decay
=
weight_decay
,
optimizer
=
optimizer
,
export_onnx
=
export_onnx
)
export_onnx
=
export_onnx
)
...
@@ -341,14 +345,14 @@ class Classification(Lightning):
...
@@ -341,14 +345,14 @@ class Classification(Lightning):
if
train_dataloader
is
not
None
:
if
train_dataloader
is
not
None
:
warnings
.
warn
(
'`train_dataloader` is deprecated and replaced with `train_dataloaders`.'
,
DeprecationWarning
)
warnings
.
warn
(
'`train_dataloader` is deprecated and replaced with `train_dataloaders`.'
,
DeprecationWarning
)
train_dataloaders
=
train_dataloader
train_dataloaders
=
train_dataloader
module
=
_
ClassificationModule
(
criterion
=
criterion
,
learning_rate
=
learning_rate
,
module
=
ClassificationModule
(
criterion
=
criterion
,
learning_rate
=
learning_rate
,
weight_decay
=
weight_decay
,
optimizer
=
optimizer
,
export_onnx
=
export_onnx
)
weight_decay
=
weight_decay
,
optimizer
=
optimizer
,
export_onnx
=
export_onnx
)
super
().
__init__
(
module
,
Trainer
(
**
trainer_kwargs
),
super
().
__init__
(
module
,
Trainer
(
**
trainer_kwargs
),
train_dataloaders
=
train_dataloaders
,
val_dataloaders
=
val_dataloaders
)
train_dataloaders
=
train_dataloaders
,
val_dataloaders
=
val_dataloaders
)
@
nni
.
trace
@
nni
.
trace
class
_
RegressionModule
(
_
SupervisedLearningModule
):
class
RegressionModule
(
SupervisedLearningModule
):
def
__init__
(
self
,
criterion
:
Type
[
nn
.
Module
]
=
nn
.
MSELoss
,
def
__init__
(
self
,
criterion
:
Type
[
nn
.
Module
]
=
nn
.
MSELoss
,
learning_rate
:
float
=
0.001
,
learning_rate
:
float
=
0.001
,
weight_decay
:
float
=
0.
,
weight_decay
:
float
=
0.
,
...
@@ -406,7 +410,14 @@ class Regression(Lightning):
...
@@ -406,7 +410,14 @@ class Regression(Lightning):
if
train_dataloader
is
not
None
:
if
train_dataloader
is
not
None
:
warnings
.
warn
(
'`train_dataloader` is deprecated and replaced with `train_dataloaders`.'
,
DeprecationWarning
)
warnings
.
warn
(
'`train_dataloader` is deprecated and replaced with `train_dataloaders`.'
,
DeprecationWarning
)
train_dataloaders
=
train_dataloader
train_dataloaders
=
train_dataloader
module
=
_
RegressionModule
(
criterion
=
criterion
,
learning_rate
=
learning_rate
,
module
=
RegressionModule
(
criterion
=
criterion
,
learning_rate
=
learning_rate
,
weight_decay
=
weight_decay
,
optimizer
=
optimizer
,
export_onnx
=
export_onnx
)
weight_decay
=
weight_decay
,
optimizer
=
optimizer
,
export_onnx
=
export_onnx
)
super
().
__init__
(
module
,
Trainer
(
**
trainer_kwargs
),
super
().
__init__
(
module
,
Trainer
(
**
trainer_kwargs
),
train_dataloaders
=
train_dataloaders
,
val_dataloaders
=
val_dataloaders
)
train_dataloaders
=
train_dataloaders
,
val_dataloaders
=
val_dataloaders
)
# Alias for backwards compatibility
_SupervisedLearningModule
=
SupervisedLearningModule
_AccuracyWithLogits
=
AccuracyWithLogits
_ClassificationModule
=
ClassificationModule
_RegressionModule
=
RegressionModule
nni/nas/execution/common/graph.py
View file @
802650ff
...
@@ -13,7 +13,7 @@ from typing import (TYPE_CHECKING, Any, Dict, Iterable, List,
...
@@ -13,7 +13,7 @@ from typing import (TYPE_CHECKING, Any, Dict, Iterable, List,
Optional
,
Set
,
Tuple
,
Type
,
Union
,
cast
,
overload
)
Optional
,
Set
,
Tuple
,
Type
,
Union
,
cast
,
overload
)
if
TYPE_CHECKING
:
if
TYPE_CHECKING
:
from
.muta
tor
import
Mutator
from
nni.nas
.muta
ble
import
Mutator
from
nni.nas.evaluator
import
Evaluator
from
nni.nas.evaluator
import
Evaluator
from
nni.nas.utils
import
uid
from
nni.nas.utils
import
uid
...
...
nni/nas/nn/pytorch/layers.py
View file @
802650ff
# Copyright (c) Microsoft Corporation.
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
# Licensed under the MIT license.
# If you've seen lint errors like `"Sequential" is not a known member of module`,
# please run `python test/vso_tools/trigger_import.py` to generate `_layers.py`.
from
pathlib
import
Path
from
pathlib
import
Path
# To make auto-completion happy, we generate a _layers.py that lists out all the classes.
# To make auto-completion happy, we generate a _layers.py that lists out all the classes.
...
...
test/algo/nas/test_cgo_engine.py
View file @
802650ff
...
@@ -152,7 +152,7 @@ def _new_trainer():
...
@@ -152,7 +152,7 @@ def _new_trainer():
train_dataset
=
serialize
(
MNIST
,
root
=
'data/mnist'
,
train
=
True
,
download
=
True
,
transform
=
transform
)
train_dataset
=
serialize
(
MNIST
,
root
=
'data/mnist'
,
train
=
True
,
download
=
True
,
transform
=
transform
)
test_dataset
=
serialize
(
MNIST
,
root
=
'data/mnist'
,
train
=
False
,
download
=
True
,
transform
=
transform
)
test_dataset
=
serialize
(
MNIST
,
root
=
'data/mnist'
,
train
=
False
,
download
=
True
,
transform
=
transform
)
multi_module
=
_MultiModelSupervisedLearningModule
(
nn
.
CrossEntropyLoss
,
{
'acc'
:
pl
.
_
AccuracyWithLogits
})
multi_module
=
_MultiModelSupervisedLearningModule
(
nn
.
CrossEntropyLoss
,
{
'acc'
:
pl
.
AccuracyWithLogits
})
lightning
=
pl
.
Lightning
(
multi_module
,
cgo_trainer
.
Trainer
(
use_cgo
=
True
,
lightning
=
pl
.
Lightning
(
multi_module
,
cgo_trainer
.
Trainer
(
use_cgo
=
True
,
max_epochs
=
1
,
max_epochs
=
1
,
...
@@ -201,7 +201,7 @@ class CGOEngineTest(unittest.TestCase):
...
@@ -201,7 +201,7 @@ class CGOEngineTest(unittest.TestCase):
train_dataset
=
serialize
(
MNIST
,
root
=
'data/mnist'
,
train
=
True
,
download
=
True
,
transform
=
transform
)
train_dataset
=
serialize
(
MNIST
,
root
=
'data/mnist'
,
train
=
True
,
download
=
True
,
transform
=
transform
)
test_dataset
=
serialize
(
MNIST
,
root
=
'data/mnist'
,
train
=
False
,
download
=
True
,
transform
=
transform
)
test_dataset
=
serialize
(
MNIST
,
root
=
'data/mnist'
,
train
=
False
,
download
=
True
,
transform
=
transform
)
multi_module
=
_MultiModelSupervisedLearningModule
(
nn
.
CrossEntropyLoss
,
{
'acc'
:
pl
.
_
AccuracyWithLogits
},
n_models
=
2
)
multi_module
=
_MultiModelSupervisedLearningModule
(
nn
.
CrossEntropyLoss
,
{
'acc'
:
pl
.
AccuracyWithLogits
},
n_models
=
2
)
lightning
=
pl
.
Lightning
(
multi_module
,
cgo_trainer
.
Trainer
(
use_cgo
=
True
,
lightning
=
pl
.
Lightning
(
multi_module
,
cgo_trainer
.
Trainer
(
use_cgo
=
True
,
max_epochs
=
1
,
max_epochs
=
1
,
...
@@ -225,7 +225,7 @@ class CGOEngineTest(unittest.TestCase):
...
@@ -225,7 +225,7 @@ class CGOEngineTest(unittest.TestCase):
train_dataset
=
serialize
(
MNIST
,
root
=
'data/mnist'
,
train
=
True
,
download
=
True
,
transform
=
transform
)
train_dataset
=
serialize
(
MNIST
,
root
=
'data/mnist'
,
train
=
True
,
download
=
True
,
transform
=
transform
)
test_dataset
=
serialize
(
MNIST
,
root
=
'data/mnist'
,
train
=
False
,
download
=
True
,
transform
=
transform
)
test_dataset
=
serialize
(
MNIST
,
root
=
'data/mnist'
,
train
=
False
,
download
=
True
,
transform
=
transform
)
multi_module
=
_MultiModelSupervisedLearningModule
(
nn
.
CrossEntropyLoss
,
{
'acc'
:
pl
.
_
AccuracyWithLogits
},
n_models
=
2
)
multi_module
=
_MultiModelSupervisedLearningModule
(
nn
.
CrossEntropyLoss
,
{
'acc'
:
pl
.
AccuracyWithLogits
},
n_models
=
2
)
lightning
=
pl
.
Lightning
(
multi_module
,
cgo_trainer
.
Trainer
(
use_cgo
=
True
,
lightning
=
pl
.
Lightning
(
multi_module
,
cgo_trainer
.
Trainer
(
use_cgo
=
True
,
max_epochs
=
1
,
max_epochs
=
1
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment