"library/vscode:/vscode.git/clone" did not exist on "ba251e4a1139911cce446509a498a01c326c377c"
Unverified commit 5874c27f authored by Yuge Zhang, committed by GitHub

Support loading supernet checkpoint in lightning (#5096)

parent 79a51d41
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
+import logging
import os
import warnings
from pathlib import Path
@@ -31,6 +32,8 @@ __all__ = [
    # FIXME: hack to make it importable for tests
]

+_logger = logging.getLogger(__name__)
+
class LightningModule(pl.LightningModule):
    """
@@ -175,6 +178,7 @@ class Lightning(Evaluator):
    def fit(self, model):
        """
        Fit the model with the provided dataloaders, using the Lightning trainer.
+        If ``train_dataloaders`` is not provided, ``trainer.validate()`` will be called.

        Parameters
        ----------
@@ -182,6 +186,12 @@ class Lightning(Evaluator):
            The model to fit.
        """
        self.module.set_model(model)
+        if self.train_dataloaders is None:
+            _logger.info('Train dataloaders are missing. Skip to validation.')
+            return self.trainer.validate(self.module, self.val_dataloaders, **self.fit_kwargs)
+        else:
+            if self.val_dataloaders is None:
+                _logger.warning('Validation dataloaders are missing.')
        return self.trainer.fit(self.module, self.train_dataloaders, self.val_dataloaders, **self.fit_kwargs)
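
For illustration, a minimal sketch of how the new fallback could be exercised. ``Classification`` and ``DataLoader`` are the serializable helpers exported by ``nni.retiarii.evaluator.pytorch``; ``val_dataset`` and ``my_model`` are assumed to be defined elsewhere:

    import nni.retiarii.evaluator.pytorch as pl

    # No train_dataloaders given, so fit() falls back to trainer.validate()
    # and the model is only evaluated, never trained.
    evaluator = pl.Classification(
        val_dataloaders=pl.DataLoader(val_dataset, batch_size=64),
        max_epochs=1,
    )
    evaluator.fit(my_model)  # runs validation only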
@@ -265,6 +275,12 @@ class SupervisedLearningModule(LightningModule):
            nni.report_intermediate_result(self._get_validation_metrics())

    def on_fit_end(self):
+        self._final_report()
+
+    def on_validation_end(self):
+        self._final_report()
+
+    def _final_report(self):
        if self.running_mode == 'multi' and nni.get_current_parameter() is not None:
            nni.report_final_result(self._get_validation_metrics())
...
@@ -3,6 +3,7 @@
import json
import logging
+from contextlib import contextmanager
from pathlib import Path
from typing import Union, Dict, Any
@@ -41,3 +42,34 @@ def fixed_arch(fixed_arch: Union[str, Path, Dict[str, Any]], verbose=True):
        _logger.info('Fixed architecture: %s', fixed_arch)
    return ContextStack('fixed', fixed_arch)
+@contextmanager
+def no_fixed_arch():
+    """
+    Ignore the ``fixed_arch()`` context.
+
+    This is useful when creating a search space within a ``fixed_arch()`` context.
+    Under the hood, it only disables the most recent fixed context, which means that
+    inside nested with-fixed-arch contexts, multiple ``no_fixed_arch()`` contexts are required.
+
+    Examples
+    --------
+    >>> with fixed_arch(arch_dict):
+    ...     with no_fixed_arch():
+    ...         model_space = ModelSpace()
+    """
+    NO_ARCH = '_no_arch_'
+
+    popped_arch = NO_ARCH  # make linter happy
+    try:
+        try:
+            popped_arch = ContextStack.pop('fixed')
+        except IndexError:
+            # No fixed-arch context is active; nothing to disable.
+            popped_arch = NO_ARCH
+        yield
+    finally:
+        if popped_arch is not NO_ARCH:
+            ContextStack.push('fixed', popped_arch)
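
Since ``no_fixed_arch()`` only pops the innermost fixed context, nested ``fixed_arch()`` contexts need one ``no_fixed_arch()`` each. A hedged sketch (``arch_a``, ``arch_b`` and ``ModelSpace`` are hypothetical):

    from nni.nas.fixed import fixed_arch, no_fixed_arch

    with fixed_arch(arch_a):              # outer fixed context
        with fixed_arch(arch_b):          # inner fixed context
            with no_fixed_arch():         # disables arch_b only
                with no_fixed_arch():     # disables arch_a as well
                    model_space = ModelSpace()  # built as a search space, not a fixed model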
@@ -8,6 +8,7 @@ import torch.nn.functional as F
import nni.nas.nn.pytorch as nn
from nni.nas import model_wrapper, basic_unit
+from nni.nas.fixed import no_fixed_arch
from nni.nas.nn.pytorch.choice import ValueChoiceX
from nni.nas.oneshot.pytorch.supermodule.operation import MixedOperation
from nni.nas.oneshot.pytorch.supermodule._valuechoice_utils import traverse_all_options
@@ -432,7 +433,7 @@ class AutoformerSpace(nn.Module):
    @classmethod
    def load_strategy_checkpoint(cls, name: str, download: bool = True, progress: bool = True):
        """
-        Load the RandomOneShot strategy initialized with supernet weights.
+        Load the related strategy checkpoints.

        Parameters
        ----------
@@ -446,14 +447,17 @@ class AutoformerSpace(nn.Module):
        Returns
        -------
        BaseStrategy
-            The RandomOneShot strategy initialized with supernet weights provided in the official repo.
+            The loaded strategy.
        """
        legal = ['random-one-shot-tiny', 'random-one-shot-small', 'random-one-shot-base']
        if name not in legal:
            raise ValueError(f'Unsupported name: {name}. It should be one of {legal}.')
        name = name[16:]

+        # RandomOneShot is the only supported strategy for now.
        from nni.nas.strategy import RandomOneShot
        init_kwargs = cls.preset(name)
+        with no_fixed_arch():
-        model_space = cls(**init_kwargs)
+            model_space = cls(**init_kwargs)
        strategy = RandomOneShot(mutation_hooks=cls.get_extra_mutation_hooks())
        strategy.attach_model(model_space)
...
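
A short usage sketch of the loading path above; the ``nni.nas.hub.pytorch`` import location is an assumption:

    from nni.nas.hub.pytorch import AutoformerSpace  # import path assumed

    # Builds the Autoformer search space (safe even inside a fixed_arch()
    # context, thanks to no_fixed_arch()) and returns a RandomOneShot
    # strategy with the official supernet weights attached.
    strategy = AutoformerSpace.load_strategy_checkpoint('random-one-shot-tiny')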
@@ -519,6 +519,12 @@ class BaseOneShotLightningModule(pl.LightningModule):
    def on_train_end(self):
        return self.model.on_train_end()

+    def on_validation_start(self):
+        return self.model.on_validation_start()
+
+    def on_validation_end(self):
+        return self.model.on_validation_end()
+
    def on_fit_start(self):
        return self.model.on_fit_start()
...
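
The hook forwarding above matters because the trainer only sees the one-shot wrapper, not the evaluator module inside it. A minimal stand-in for the pattern (not the actual NNI class):

    import pytorch_lightning as pl

    class Wrapper(pl.LightningModule):
        def __init__(self, inner):
            super().__init__()
            self.model = inner

        def on_validation_end(self):
            # Forward the hook so the inner module's reporting logic
            # (e.g. _final_report above) still runs under trainer.validate().
            return self.model.on_validation_end()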
@@ -61,6 +61,7 @@ class OneShotStrategy(BaseStrategy):
            evaluator_module.running_mode = 'oneshot'
            evaluator_module.set_model(py_model)
        else:
+            # FIXME: this should be an evaluator + model
            from nni.retiarii.evaluator.pytorch.lightning import ClassificationModule
            evaluator_module = ClassificationModule()
            evaluator_module.running_mode = 'oneshot'
...
@@ -106,8 +106,8 @@ class ContextStack:
        cls._stack[key].append(value)

    @classmethod
-    def pop(cls, key: str) -> None:
-        cls._stack[key].pop()
+    def pop(cls, key: str) -> Any:
+        return cls._stack[key].pop()

    @classmethod
    def top(cls, key: str) -> Any:
...
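
Returning the popped value is what lets ``no_fixed_arch()`` save and later restore the context. The same save/restore pattern, sketched on a plain dict-of-lists stand-in for ``ContextStack._stack``:

    from contextlib import contextmanager

    _stack = {'fixed': ['arch']}  # stand-in for ContextStack._stack

    @contextmanager
    def suspend_top(key):
        value = _stack[key].pop()  # pop() must return the value so it can be restored
        try:
            yield
        finally:
            _stack[key].append(value)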