Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
nni
Commits
5874c27f
"library/vscode:/vscode.git/clone" did not exist on "ba251e4a1139911cce446509a498a01c326c377c"
Unverified
Commit
5874c27f
authored
Sep 02, 2022
by
Yuge Zhang
Committed by
GitHub
Sep 02, 2022
Browse files
Support loading supernet checkpoint in lightning (#5096)
parent
79a51d41
Changes
6
Show whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
65 additions
and
6 deletions
+65
-6
nni/nas/evaluator/pytorch/lightning.py
nni/nas/evaluator/pytorch/lightning.py
+17
-1
nni/nas/fixed.py
nni/nas/fixed.py
+32
-0
nni/nas/hub/pytorch/autoformer.py
nni/nas/hub/pytorch/autoformer.py
+7
-3
nni/nas/oneshot/pytorch/base_lightning.py
nni/nas/oneshot/pytorch/base_lightning.py
+6
-0
nni/nas/oneshot/pytorch/strategy.py
nni/nas/oneshot/pytorch/strategy.py
+1
-0
nni/nas/utils/misc.py
nni/nas/utils/misc.py
+2
-2
No files found.
nni/nas/evaluator/pytorch/lightning.py
View file @
5874c27f
# Copyright (c) Microsoft Corporation.
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
# Licensed under the MIT license.
import
logging
import
os
import
os
import
warnings
import
warnings
from
pathlib
import
Path
from
pathlib
import
Path
...
@@ -31,6 +32,8 @@ __all__ = [
...
@@ -31,6 +32,8 @@ __all__ = [
# FIXME: hack to make it importable for tests
# FIXME: hack to make it importable for tests
]
]
_logger
=
logging
.
getLogger
(
__name__
)
class
LightningModule
(
pl
.
LightningModule
):
class
LightningModule
(
pl
.
LightningModule
):
"""
"""
...
@@ -175,6 +178,7 @@ class Lightning(Evaluator):
...
@@ -175,6 +178,7 @@ class Lightning(Evaluator):
def
fit
(
self
,
model
):
def
fit
(
self
,
model
):
"""
"""
Fit the model with provided dataloader, with Lightning trainer.
Fit the model with provided dataloader, with Lightning trainer.
If ``train_dataloaders`` is not provided, ``trainer.validate()`` will be called.
Parameters
Parameters
----------
----------
...
@@ -182,6 +186,12 @@ class Lightning(Evaluator):
...
@@ -182,6 +186,12 @@ class Lightning(Evaluator):
The model to fit.
The model to fit.
"""
"""
self
.
module
.
set_model
(
model
)
self
.
module
.
set_model
(
model
)
if
self
.
train_dataloaders
is
None
:
_logger
.
info
(
'Train dataloaders are missing. Skip to validation.'
)
return
self
.
trainer
.
validate
(
self
.
module
,
self
.
val_dataloaders
,
**
self
.
fit_kwargs
)
else
:
if
self
.
val_dataloaders
is
None
:
_logger
.
warning
(
'Validation dataloaders are missing.'
)
return
self
.
trainer
.
fit
(
self
.
module
,
self
.
train_dataloaders
,
self
.
val_dataloaders
,
**
self
.
fit_kwargs
)
return
self
.
trainer
.
fit
(
self
.
module
,
self
.
train_dataloaders
,
self
.
val_dataloaders
,
**
self
.
fit_kwargs
)
...
@@ -265,6 +275,12 @@ class SupervisedLearningModule(LightningModule):
...
@@ -265,6 +275,12 @@ class SupervisedLearningModule(LightningModule):
nni
.
report_intermediate_result
(
self
.
_get_validation_metrics
())
nni
.
report_intermediate_result
(
self
.
_get_validation_metrics
())
def
on_fit_end
(
self
):
def
on_fit_end
(
self
):
self
.
_final_report
()
def
on_validation_end
(
self
):
self
.
_final_report
()
def
_final_report
(
self
):
if
self
.
running_mode
==
'multi'
and
nni
.
get_current_parameter
()
is
not
None
:
if
self
.
running_mode
==
'multi'
and
nni
.
get_current_parameter
()
is
not
None
:
nni
.
report_final_result
(
self
.
_get_validation_metrics
())
nni
.
report_final_result
(
self
.
_get_validation_metrics
())
...
...
nni/nas/fixed.py
View file @
5874c27f
...
@@ -3,6 +3,7 @@
...
@@ -3,6 +3,7 @@
import
json
import
json
import
logging
import
logging
from
contextlib
import
contextmanager
from
pathlib
import
Path
from
pathlib
import
Path
from
typing
import
Union
,
Dict
,
Any
from
typing
import
Union
,
Dict
,
Any
...
@@ -41,3 +42,34 @@ def fixed_arch(fixed_arch: Union[str, Path, Dict[str, Any]], verbose=True):
...
@@ -41,3 +42,34 @@ def fixed_arch(fixed_arch: Union[str, Path, Dict[str, Any]], verbose=True):
_logger
.
info
(
f
'Fixed architecture: %s'
,
fixed_arch
)
_logger
.
info
(
f
'Fixed architecture: %s'
,
fixed_arch
)
return
ContextStack
(
'fixed'
,
fixed_arch
)
return
ContextStack
(
'fixed'
,
fixed_arch
)
@
contextmanager
def
no_fixed_arch
():
"""
Ignore the ``fixed_arch()`` context.
This is useful in creating a search space within a ``fixed_arch()`` context.
Under the hood, it only disables the most recent one fixed context, which means,
if it's currently in a nested with-fixed-arch context, multiple ``no_fixed_arch()`` contexts is required.
Examples
--------
>>> with fixed_arch(arch_dict):
... with no_fixed_arch():
... model_space = ModelSpace()
"""
NO_ARCH
=
'_no_arch_'
popped_arch
=
NO_ARCH
# make linter happy
try
:
try
:
popped_arch
=
ContextStack
.
pop
(
'fixed'
)
except
IndexError
:
# context unavailable
popped_arch
=
NO_ARCH
yield
finally
:
if
popped_arch
is
not
NO_ARCH
:
ContextStack
.
push
(
'fixed'
,
popped_arch
)
nni/nas/hub/pytorch/autoformer.py
View file @
5874c27f
...
@@ -8,6 +8,7 @@ import torch.nn.functional as F
...
@@ -8,6 +8,7 @@ import torch.nn.functional as F
import
nni.nas.nn.pytorch
as
nn
import
nni.nas.nn.pytorch
as
nn
from
nni.nas
import
model_wrapper
,
basic_unit
from
nni.nas
import
model_wrapper
,
basic_unit
from
nni.nas.fixed
import
no_fixed_arch
from
nni.nas.nn.pytorch.choice
import
ValueChoiceX
from
nni.nas.nn.pytorch.choice
import
ValueChoiceX
from
nni.nas.oneshot.pytorch.supermodule.operation
import
MixedOperation
from
nni.nas.oneshot.pytorch.supermodule.operation
import
MixedOperation
from
nni.nas.oneshot.pytorch.supermodule._valuechoice_utils
import
traverse_all_options
from
nni.nas.oneshot.pytorch.supermodule._valuechoice_utils
import
traverse_all_options
...
@@ -432,7 +433,7 @@ class AutoformerSpace(nn.Module):
...
@@ -432,7 +433,7 @@ class AutoformerSpace(nn.Module):
@
classmethod
@
classmethod
def
load_strategy_checkpoint
(
cls
,
name
:
str
,
download
:
bool
=
True
,
progress
:
bool
=
True
):
def
load_strategy_checkpoint
(
cls
,
name
:
str
,
download
:
bool
=
True
,
progress
:
bool
=
True
):
"""
"""
Load the
RandomOneShot strategy initialized with supernet weigh
ts.
Load the
related strategy checkpoin
ts.
Parameters
Parameters
----------
----------
...
@@ -446,14 +447,17 @@ class AutoformerSpace(nn.Module):
...
@@ -446,14 +447,17 @@ class AutoformerSpace(nn.Module):
Returns
Returns
-------
-------
BaseStrategy
BaseStrategy
The
RandomOneShot strategy initialized with supernet weights provided in the official repo
.
The
loaded strategy
.
"""
"""
legal
=
[
'random-one-shot-tiny'
,
'random-one-shot-small'
,
'random-one-shot-base'
]
legal
=
[
'random-one-shot-tiny'
,
'random-one-shot-small'
,
'random-one-shot-base'
]
if
name
not
in
legal
:
if
name
not
in
legal
:
raise
ValueError
(
f
'Unsupported name:
{
name
}
. It should be one of
{
legal
}
.'
)
raise
ValueError
(
f
'Unsupported name:
{
name
}
. It should be one of
{
legal
}
.'
)
name
=
name
[
16
:]
name
=
name
[
16
:]
# RandomOneShot is the only supported strategy for now.
from
nni.nas.strategy
import
RandomOneShot
from
nni.nas.strategy
import
RandomOneShot
init_kwargs
=
cls
.
preset
(
name
)
init_kwargs
=
cls
.
preset
(
name
)
with
no_fixed_arch
():
model_sapce
=
cls
(
**
init_kwargs
)
model_sapce
=
cls
(
**
init_kwargs
)
strategy
=
RandomOneShot
(
mutation_hooks
=
cls
.
get_extra_mutation_hooks
())
strategy
=
RandomOneShot
(
mutation_hooks
=
cls
.
get_extra_mutation_hooks
())
strategy
.
attach_model
(
model_sapce
)
strategy
.
attach_model
(
model_sapce
)
...
...
nni/nas/oneshot/pytorch/base_lightning.py
View file @
5874c27f
...
@@ -519,6 +519,12 @@ class BaseOneShotLightningModule(pl.LightningModule):
...
@@ -519,6 +519,12 @@ class BaseOneShotLightningModule(pl.LightningModule):
def
on_train_end
(
self
):
def
on_train_end
(
self
):
return
self
.
model
.
on_train_end
()
return
self
.
model
.
on_train_end
()
def
on_validation_start
(
self
):
return
self
.
model
.
on_validation_start
()
def
on_validation_end
(
self
):
return
self
.
model
.
on_validation_end
()
def
on_fit_start
(
self
):
def
on_fit_start
(
self
):
return
self
.
model
.
on_fit_start
()
return
self
.
model
.
on_fit_start
()
...
...
nni/nas/oneshot/pytorch/strategy.py
View file @
5874c27f
...
@@ -61,6 +61,7 @@ class OneShotStrategy(BaseStrategy):
...
@@ -61,6 +61,7 @@ class OneShotStrategy(BaseStrategy):
evaluator_module
.
running_mode
=
'oneshot'
evaluator_module
.
running_mode
=
'oneshot'
evaluator_module
.
set_model
(
py_model
)
evaluator_module
.
set_model
(
py_model
)
else
:
else
:
# FIXME: this should be an evaluator + model
from
nni.retiarii.evaluator.pytorch.lightning
import
ClassificationModule
from
nni.retiarii.evaluator.pytorch.lightning
import
ClassificationModule
evaluator_module
=
ClassificationModule
()
evaluator_module
=
ClassificationModule
()
evaluator_module
.
running_mode
=
'oneshot'
evaluator_module
.
running_mode
=
'oneshot'
...
...
nni/nas/utils/misc.py
View file @
5874c27f
...
@@ -106,8 +106,8 @@ class ContextStack:
...
@@ -106,8 +106,8 @@ class ContextStack:
cls
.
_stack
[
key
].
append
(
value
)
cls
.
_stack
[
key
].
append
(
value
)
@
classmethod
@
classmethod
def
pop
(
cls
,
key
:
str
)
->
None
:
def
pop
(
cls
,
key
:
str
)
->
Any
:
cls
.
_stack
[
key
].
pop
()
return
cls
.
_stack
[
key
].
pop
()
@
classmethod
@
classmethod
def
top
(
cls
,
key
:
str
)
->
Any
:
def
top
(
cls
,
key
:
str
)
->
Any
:
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment