Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
nni
Commits
5874c27f
Unverified
Commit
5874c27f
authored
Sep 02, 2022
by
Yuge Zhang
Committed by
GitHub
Sep 02, 2022
Browse files
Support loading supernet checkpoint in lightning (#5096)
parent
79a51d41
Changes
6
Hide whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
65 additions
and
6 deletions
+65
-6
nni/nas/evaluator/pytorch/lightning.py
nni/nas/evaluator/pytorch/lightning.py
+17
-1
nni/nas/fixed.py
nni/nas/fixed.py
+32
-0
nni/nas/hub/pytorch/autoformer.py
nni/nas/hub/pytorch/autoformer.py
+7
-3
nni/nas/oneshot/pytorch/base_lightning.py
nni/nas/oneshot/pytorch/base_lightning.py
+6
-0
nni/nas/oneshot/pytorch/strategy.py
nni/nas/oneshot/pytorch/strategy.py
+1
-0
nni/nas/utils/misc.py
nni/nas/utils/misc.py
+2
-2
No files found.
nni/nas/evaluator/pytorch/lightning.py
View file @
5874c27f
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import
logging
import
os
import
warnings
from
pathlib
import
Path
...
...
@@ -31,6 +32,8 @@ __all__ = [
# FIXME: hack to make it importable for tests
]
_logger
=
logging
.
getLogger
(
__name__
)
class
LightningModule
(
pl
.
LightningModule
):
"""
...
...
@@ -175,6 +178,7 @@ class Lightning(Evaluator):
def
fit
(
self
,
model
):
"""
Fit the model with provided dataloader, with Lightning trainer.
If ``train_dataloaders`` is not provided, ``trainer.validate()`` will be called.
Parameters
----------
...
...
@@ -182,7 +186,13 @@ class Lightning(Evaluator):
The model to fit.
"""
self
.
module
.
set_model
(
model
)
return
self
.
trainer
.
fit
(
self
.
module
,
self
.
train_dataloaders
,
self
.
val_dataloaders
,
**
self
.
fit_kwargs
)
if
self
.
train_dataloaders
is
None
:
_logger
.
info
(
'Train dataloaders are missing. Skip to validation.'
)
return
self
.
trainer
.
validate
(
self
.
module
,
self
.
val_dataloaders
,
**
self
.
fit_kwargs
)
else
:
if
self
.
val_dataloaders
is
None
:
_logger
.
warning
(
'Validation dataloaders are missing.'
)
return
self
.
trainer
.
fit
(
self
.
module
,
self
.
train_dataloaders
,
self
.
val_dataloaders
,
**
self
.
fit_kwargs
)
def
_check_dataloader
(
dataloader
):
...
...
@@ -265,6 +275,12 @@ class SupervisedLearningModule(LightningModule):
nni
.
report_intermediate_result
(
self
.
_get_validation_metrics
())
def
on_fit_end
(
self
):
self
.
_final_report
()
def
on_validation_end
(
self
):
self
.
_final_report
()
def
_final_report
(
self
):
if
self
.
running_mode
==
'multi'
and
nni
.
get_current_parameter
()
is
not
None
:
nni
.
report_final_result
(
self
.
_get_validation_metrics
())
...
...
nni/nas/fixed.py
View file @
5874c27f
...
...
@@ -3,6 +3,7 @@
import
json
import
logging
from
contextlib
import
contextmanager
from
pathlib
import
Path
from
typing
import
Union
,
Dict
,
Any
...
...
@@ -41,3 +42,34 @@ def fixed_arch(fixed_arch: Union[str, Path, Dict[str, Any]], verbose=True):
_logger
.
info
(
f
'Fixed architecture: %s'
,
fixed_arch
)
return
ContextStack
(
'fixed'
,
fixed_arch
)
@
contextmanager
def
no_fixed_arch
():
"""
Ignore the ``fixed_arch()`` context.
This is useful in creating a search space within a ``fixed_arch()`` context.
Under the hood, it only disables the most recent one fixed context, which means,
if it's currently in a nested with-fixed-arch context, multiple ``no_fixed_arch()`` contexts is required.
Examples
--------
>>> with fixed_arch(arch_dict):
... with no_fixed_arch():
... model_space = ModelSpace()
"""
NO_ARCH
=
'_no_arch_'
popped_arch
=
NO_ARCH
# make linter happy
try
:
try
:
popped_arch
=
ContextStack
.
pop
(
'fixed'
)
except
IndexError
:
# context unavailable
popped_arch
=
NO_ARCH
yield
finally
:
if
popped_arch
is
not
NO_ARCH
:
ContextStack
.
push
(
'fixed'
,
popped_arch
)
nni/nas/hub/pytorch/autoformer.py
View file @
5874c27f
...
...
@@ -8,6 +8,7 @@ import torch.nn.functional as F
import
nni.nas.nn.pytorch
as
nn
from
nni.nas
import
model_wrapper
,
basic_unit
from
nni.nas.fixed
import
no_fixed_arch
from
nni.nas.nn.pytorch.choice
import
ValueChoiceX
from
nni.nas.oneshot.pytorch.supermodule.operation
import
MixedOperation
from
nni.nas.oneshot.pytorch.supermodule._valuechoice_utils
import
traverse_all_options
...
...
@@ -432,7 +433,7 @@ class AutoformerSpace(nn.Module):
@
classmethod
def
load_strategy_checkpoint
(
cls
,
name
:
str
,
download
:
bool
=
True
,
progress
:
bool
=
True
):
"""
Load the
RandomOneShot strategy initialized with supernet weigh
ts.
Load the
related strategy checkpoin
ts.
Parameters
----------
...
...
@@ -446,15 +447,18 @@ class AutoformerSpace(nn.Module):
Returns
-------
BaseStrategy
The
RandomOneShot strategy initialized with supernet weights provided in the official repo
.
The
loaded strategy
.
"""
legal
=
[
'random-one-shot-tiny'
,
'random-one-shot-small'
,
'random-one-shot-base'
]
if
name
not
in
legal
:
raise
ValueError
(
f
'Unsupported name:
{
name
}
. It should be one of
{
legal
}
.'
)
name
=
name
[
16
:]
# RandomOneShot is the only supported strategy for now.
from
nni.nas.strategy
import
RandomOneShot
init_kwargs
=
cls
.
preset
(
name
)
model_sapce
=
cls
(
**
init_kwargs
)
with
no_fixed_arch
():
model_sapce
=
cls
(
**
init_kwargs
)
strategy
=
RandomOneShot
(
mutation_hooks
=
cls
.
get_extra_mutation_hooks
())
strategy
.
attach_model
(
model_sapce
)
weight_file
=
load_pretrained_weight
(
f
"autoformer-
{
name
}
-supernet"
,
download
=
download
,
progress
=
progress
)
...
...
nni/nas/oneshot/pytorch/base_lightning.py
View file @
5874c27f
...
...
@@ -519,6 +519,12 @@ class BaseOneShotLightningModule(pl.LightningModule):
def
on_train_end
(
self
):
return
self
.
model
.
on_train_end
()
def
on_validation_start
(
self
):
return
self
.
model
.
on_validation_start
()
def
on_validation_end
(
self
):
return
self
.
model
.
on_validation_end
()
def
on_fit_start
(
self
):
return
self
.
model
.
on_fit_start
()
...
...
nni/nas/oneshot/pytorch/strategy.py
View file @
5874c27f
...
...
@@ -61,6 +61,7 @@ class OneShotStrategy(BaseStrategy):
evaluator_module
.
running_mode
=
'oneshot'
evaluator_module
.
set_model
(
py_model
)
else
:
# FIXME: this should be an evaluator + model
from
nni.retiarii.evaluator.pytorch.lightning
import
ClassificationModule
evaluator_module
=
ClassificationModule
()
evaluator_module
.
running_mode
=
'oneshot'
...
...
nni/nas/utils/misc.py
View file @
5874c27f
...
...
@@ -106,8 +106,8 @@ class ContextStack:
cls
.
_stack
[
key
].
append
(
value
)
@
classmethod
def
pop
(
cls
,
key
:
str
)
->
None
:
cls
.
_stack
[
key
].
pop
()
def
pop
(
cls
,
key
:
str
)
->
Any
:
return
cls
.
_stack
[
key
].
pop
()
@
classmethod
def
top
(
cls
,
key
:
str
)
->
Any
:
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment