Commit 5874c27f (unverified), authored by Yuge Zhang and committed by GitHub
Browse files

Support loading supernet checkpoint in lightning (#5096)

parent 79a51d41
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import logging
import os
import warnings
from pathlib import Path
......@@ -31,6 +32,8 @@ __all__ = [
# FIXME: hack to make it importable for tests
]
_logger = logging.getLogger(__name__)
class LightningModule(pl.LightningModule):
"""
......@@ -175,6 +178,7 @@ class Lightning(Evaluator):
def fit(self, model):
"""
Fit the model with provided dataloader, with Lightning trainer.
If ``train_dataloaders`` is not provided, ``trainer.validate()`` will be called.
Parameters
----------
......@@ -182,6 +186,12 @@ class Lightning(Evaluator):
The model to fit.
"""
self.module.set_model(model)
if self.train_dataloaders is None:
_logger.info('Train dataloaders are missing. Skip to validation.')
return self.trainer.validate(self.module, self.val_dataloaders, **self.fit_kwargs)
else:
if self.val_dataloaders is None:
_logger.warning('Validation dataloaders are missing.')
return self.trainer.fit(self.module, self.train_dataloaders, self.val_dataloaders, **self.fit_kwargs)
......@@ -265,6 +275,12 @@ class SupervisedLearningModule(LightningModule):
nni.report_intermediate_result(self._get_validation_metrics())
def on_fit_end(self):
    """``on_fit_end`` hook: hand over to ``_final_report`` to emit the final result."""
    report = self._final_report
    report()
def on_validation_end(self):
    """``on_validation_end`` hook: hand over to ``_final_report`` to emit the final result.

    NOTE(review): Lightning presumably also runs validation loops inside
    ``trainer.fit()``; confirm that issuing the final report from every such
    loop (in addition to ``on_fit_end``) is intended.
    """
    report = self._final_report
    report()
def _final_report(self):
if self.running_mode == 'multi' and nni.get_current_parameter() is not None:
nni.report_final_result(self._get_validation_metrics())
......
......@@ -3,6 +3,7 @@
import json
import logging
from contextlib import contextmanager
from pathlib import Path
from typing import Union, Dict, Any
......@@ -41,3 +42,34 @@ def fixed_arch(fixed_arch: Union[str, Path, Dict[str, Any]], verbose=True):
_logger.info(f'Fixed architecture: %s', fixed_arch)
return ContextStack('fixed', fixed_arch)
@contextmanager
def no_fixed_arch():
    """
    Ignore the ``fixed_arch()`` context.

    This is useful in creating a search space within a ``fixed_arch()`` context.
    Under the hood, it only disables the most recent fixed context, which means,
    if it's currently in a nested with-fixed-arch context,
    multiple ``no_fixed_arch()`` contexts are required.

    If no fixed context is available on entry, nothing is popped and nothing
    is restored on exit.

    Examples
    --------
    >>> with fixed_arch(arch_dict):
    ...     with no_fixed_arch():
    ...         model_space = ModelSpace()
    """
    # A fresh object() can never be identical to a real popped value, whereas a
    # string-literal sentinel could alias an equal interned string.
    nothing_popped = object()

    try:
        popped_arch = ContextStack.pop('fixed')
    except IndexError:
        # context unavailable: there is nothing to put back afterwards
        popped_arch = nothing_popped

    try:
        yield
    finally:
        # Restore the disabled context even if the with-body raised.
        if popped_arch is not nothing_popped:
            ContextStack.push('fixed', popped_arch)
......@@ -8,6 +8,7 @@ import torch.nn.functional as F
import nni.nas.nn.pytorch as nn
from nni.nas import model_wrapper, basic_unit
from nni.nas.fixed import no_fixed_arch
from nni.nas.nn.pytorch.choice import ValueChoiceX
from nni.nas.oneshot.pytorch.supermodule.operation import MixedOperation
from nni.nas.oneshot.pytorch.supermodule._valuechoice_utils import traverse_all_options
......@@ -432,7 +433,7 @@ class AutoformerSpace(nn.Module):
@classmethod
def load_strategy_checkpoint(cls, name: str, download: bool = True, progress: bool = True):
"""
Load the RandomOneShot strategy initialized with supernet weights.
Load the related strategy checkpoints.
Parameters
----------
......@@ -446,14 +447,17 @@ class AutoformerSpace(nn.Module):
Returns
-------
BaseStrategy
The RandomOneShot strategy initialized with supernet weights provided in the official repo.
The loaded strategy.
"""
legal = ['random-one-shot-tiny', 'random-one-shot-small', 'random-one-shot-base']
if name not in legal:
raise ValueError(f'Unsupported name: {name}. It should be one of {legal}.')
name = name[16:]
# RandomOneShot is the only supported strategy for now.
from nni.nas.strategy import RandomOneShot
init_kwargs = cls.preset(name)
with no_fixed_arch():
model_sapce = cls(**init_kwargs)
strategy = RandomOneShot(mutation_hooks=cls.get_extra_mutation_hooks())
strategy.attach_model(model_sapce)
......
......@@ -519,6 +519,12 @@ class BaseOneShotLightningModule(pl.LightningModule):
def on_train_end(self):
    """Forward the ``on_train_end`` hook to the wrapped ``self.model`` and return its result."""
    inner = self.model
    return inner.on_train_end()
def on_validation_start(self):
    """Forward the ``on_validation_start`` hook to the wrapped ``self.model`` and return its result."""
    inner = self.model
    return inner.on_validation_start()
def on_validation_end(self):
    """Forward the ``on_validation_end`` hook to the wrapped ``self.model`` and return its result."""
    inner = self.model
    return inner.on_validation_end()
def on_fit_start(self):
    """Forward the ``on_fit_start`` hook to the wrapped ``self.model`` and return its result."""
    inner = self.model
    return inner.on_fit_start()
......
......@@ -61,6 +61,7 @@ class OneShotStrategy(BaseStrategy):
evaluator_module.running_mode = 'oneshot'
evaluator_module.set_model(py_model)
else:
# FIXME: this should be an evaluator + model
from nni.retiarii.evaluator.pytorch.lightning import ClassificationModule
evaluator_module = ClassificationModule()
evaluator_module.running_mode = 'oneshot'
......
......@@ -106,8 +106,8 @@ class ContextStack:
cls._stack[key].append(value)
@classmethod
def pop(cls, key: str) -> Any:
    """Remove the most recently pushed value stored under ``key`` and return it.

    The value is taken from the end of ``cls._stack[key]``; any exception
    raised by that underlying ``pop()`` propagates to the caller.
    """
    return cls._stack[key].pop()
@classmethod
def top(cls, key: str) -> Any:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment