Unverified commit ac892fc7, authored by Yuge Zhang, committed by GitHub

Model space hub enhancements (v2.9) (#5050)

parent 802650ff
@@ -7,7 +7,6 @@ The implementation is based on NDS.
 It's called ``nasnet.py`` simply because NASNet is the first to propose such structure.
 """
-from collections import OrderedDict
 from functools import partial
 from typing import Tuple, List, Union, Iterable, Dict, Callable, Optional, cast
@@ -235,20 +234,6 @@ class AuxiliaryHead(nn.Module):
         return x
-class SequentialBreakdown(nn.Sequential):
-    """Return all layers of a sequential."""
-    def __init__(self, sequential: nn.Sequential):
-        super().__init__(OrderedDict(sequential.named_children()))
-    def forward(self, inputs):
-        result = []
-        for module in self:
-            inputs = module(inputs)
-            result.append(inputs)
-        return result
 class CellPreprocessor(nn.Module):
     """
     Aligning the shape of predecessors.
@@ -296,7 +281,8 @@ class CellBuilder:
                  C: nn.MaybeChoice[int],
                  num_nodes: int,
                  merge_op: Literal['all', 'loose_end'],
-                 first_cell_reduce: bool, last_cell_reduce: bool):
+                 first_cell_reduce: bool, last_cell_reduce: bool,
+                 drop_path_prob: float):
         self.C_prev_in = C_prev_in  # This is the out channels of the cell before last cell.
         self.C_in = C_in  # This is the out channesl of last cell.
         self.C = C  # This is NOT C_out of this stage, instead, C_out = C * len(cell.output_node_indices)
@@ -305,6 +291,7 @@ class CellBuilder:
         self.merge_op: Literal['all', 'loose_end'] = merge_op
         self.first_cell_reduce = first_cell_reduce
         self.last_cell_reduce = last_cell_reduce
+        self.drop_path_prob = drop_path_prob
         self._expect_idx = 0
         # It takes an index that is the index in the repeat.
@@ -318,11 +305,16 @@ class CellBuilder:
                     op: str, channels: int, is_reduction_cell: bool):
         if is_reduction_cell and (
             input_index is None or input_index < self.num_predecessors
-        ):  # could be none when constructing search sapce
+        ):  # could be none when constructing search space
             stride = 2
         else:
             stride = 1
-        return OPS[op](channels, stride, True)
+        operation = OPS[op](channels, stride, True)
+        if self.drop_path_prob > 0 and not isinstance(operation, nn.Identity):
+            # Omit drop-path when operation is skip connect.
+            # https://github.com/quark0/darts/blob/f276dd346a09ae3160f8e3aca5c7b193fda1da37/cnn/model.py#L54
+            return nn.Sequential(operation, DropPath_(self.drop_path_prob))
+        return operation

     def __call__(self, repeat_idx: int):
         if self._expect_idx != repeat_idx:
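For context, drop path (as used in NASNet/DARTS retraining) zeroes the entire output of an operation for randomly chosen samples during training and rescales the surviving ones, which regularizes deep cell-based networks. Below is a minimal sketch of such a module, assuming standard DARTS-style behaviour; it is an illustration only, not the exact DropPath_ implementation referenced in the hunk above.

import torch
import torch.nn as nn

class DropPathSketch(nn.Module):
    """Illustrative drop-path module (assumed standard DARTS-style semantics)."""

    def __init__(self, p: float = 0.):
        super().__init__()
        self.p = p  # probability of dropping the whole path for a sample

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if self.training and self.p > 0.:
            keep = 1. - self.p
            # One Bernoulli draw per sample (NCHW input assumed), broadcast over C, H, W.
            mask = torch.zeros(x.size(0), 1, 1, 1, device=x.device).bernoulli_(keep)
            x = x / keep * mask  # rescale survivors so the expected activation is unchanged
        return x

In the op factory above, the wrapper is skipped for nn.Identity so that skip connections are never dropped, matching the reference DARTS implementation linked in the comment.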
@@ -483,6 +475,8 @@ class NDS(nn.Module):
         See :class:`~nni.retiarii.nn.pytorch.Cell`.
     num_nodes_per_cell : int
         See :class:`~nni.retiarii.nn.pytorch.Cell`.
+    drop_path_prob : float
+        Apply drop path. Enabled when it's set to be greater than 0.
     """

     def __init__(self,
@@ -492,12 +486,14 @@ class NDS(nn.Module):
                  width: Union[Tuple[int, ...], int] = 16,
                  num_cells: Union[Tuple[int, ...], int] = 20,
                  dataset: Literal['cifar', 'imagenet'] = 'imagenet',
-                 auxiliary_loss: bool = False):
+                 auxiliary_loss: bool = False,
+                 drop_path_prob: float = 0.):
         super().__init__()
         self.dataset = dataset
         self.num_labels = 10 if dataset == 'cifar' else 1000
         self.auxiliary_loss = auxiliary_loss
+        self.drop_path_prob = drop_path_prob
         # preprocess the specified width and depth
         if isinstance(width, Iterable):
@@ -546,7 +542,7 @@ class NDS(nn.Module):
             # C_curr is number of channels for each operator in current stage.
             # C_out is usually `C * num_nodes_per_cell` because of concat operator.
             cell_builder = CellBuilder(op_candidates, C_pprev, C_prev, C_curr, num_nodes_per_cell,
-                                       merge_op, stage_idx > 0, last_cell_reduce)
+                                       merge_op, stage_idx > 0, last_cell_reduce, drop_path_prob)
             stage: Union[NDSStage, nn.Sequential] = NDSStage(cell_builder, num_cells_per_stage[stage_idx])
             if isinstance(stage, NDSStage):
@@ -581,7 +577,6 @@ class NDS(nn.Module):
         if auxiliary_loss:
             assert isinstance(self.stages[2], nn.Sequential), 'Auxiliary loss can only be enabled in retrain mode.'
-            self.stages[2] = SequentialBreakdown(cast(nn.Sequential, self.stages[2]))
             self.auxiliary_head = AuxiliaryHead(C_to_auxiliary, self.num_labels, dataset=self.dataset)  # type: ignore
         self.global_pooling = nn.AdaptiveAvgPool2d((1, 1))
@@ -595,12 +590,13 @@ class NDS(nn.Module):
         s0 = s1 = self.stem(inputs)
         for stage_idx, stage in enumerate(self.stages):
-            if stage_idx == 2 and self.auxiliary_loss:
-                s = list(stage([s0, s1]).values())
-                s0, s1 = s[-1]
-                if self.training:
-                    # auxiliary loss is attached to the first cell of the last stage.
-                    logits_aux = self.auxiliary_head(s[0][1])
+            if stage_idx == 2 and self.auxiliary_loss and self.training:
+                assert isinstance(stage, nn.Sequential), 'Auxiliary loss is only supported for fixed architecture.'
+                for block_idx, block in enumerate(stage):
+                    # auxiliary loss is attached to the first cell of the last stage.
+                    s0, s1 = block([s0, s1])
+                    if block_idx == 0:
+                        logits_aux = self.auxiliary_head(s1)
             else:
                 s0, s1 = stage([s0, s1])
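With this change, the third stage is executed cell by cell when auxiliary loss is enabled, and the output of its first cell feeds the auxiliary head, so a training forward pass produces both main and auxiliary logits. A typical way to combine the two is a down-weighted sum; the sketch below assumes the model returns a (logits, logits_aux) pair in this mode, and the weight 0.4 is an illustrative value from the common DARTS/NASNet retraining recipe, not something fixed by this diff.

import torch.nn.functional as F

def training_step_sketch(model, images, labels, aux_weight=0.4):
    # Assumes the model was built with auxiliary_loss=True, is in training mode,
    # and that forward returns (main logits, auxiliary logits) in that case.
    logits, logits_aux = model(images)
    loss = F.cross_entropy(logits, labels)
    loss = loss + aux_weight * F.cross_entropy(logits_aux, labels)
    return loss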
@@ -655,14 +651,16 @@ class NASNet(NDS):
                  width: Union[Tuple[int, ...], int] = (16, 24, 32),
                  num_cells: Union[Tuple[int, ...], int] = (4, 8, 12, 16, 20),
                  dataset: Literal['cifar', 'imagenet'] = 'cifar',
-                 auxiliary_loss: bool = False):
+                 auxiliary_loss: bool = False,
+                 drop_path_prob: float = 0.):
         super().__init__(self.NASNET_OPS,
                          merge_op='loose_end',
                          num_nodes_per_cell=5,
                          width=width,
                          num_cells=num_cells,
                          dataset=dataset,
-                         auxiliary_loss=auxiliary_loss)
+                         auxiliary_loss=auxiliary_loss,
+                         drop_path_prob=drop_path_prob)

 @model_wrapper
@@ -686,14 +684,16 @@ class ENAS(NDS):
                  width: Union[Tuple[int, ...], int] = (16, 24, 32),
                  num_cells: Union[Tuple[int, ...], int] = (4, 8, 12, 16, 20),
                  dataset: Literal['cifar', 'imagenet'] = 'cifar',
-                 auxiliary_loss: bool = False):
+                 auxiliary_loss: bool = False,
+                 drop_path_prob: float = 0.):
         super().__init__(self.ENAS_OPS,
                          merge_op='loose_end',
                          num_nodes_per_cell=5,
                          width=width,
                          num_cells=num_cells,
                          dataset=dataset,
-                         auxiliary_loss=auxiliary_loss)
+                         auxiliary_loss=auxiliary_loss,
+                         drop_path_prob=drop_path_prob)

 @model_wrapper
@@ -721,7 +721,8 @@ class AmoebaNet(NDS):
                  width: Union[Tuple[int, ...], int] = (16, 24, 32),
                  num_cells: Union[Tuple[int, ...], int] = (4, 8, 12, 16, 20),
                  dataset: Literal['cifar', 'imagenet'] = 'cifar',
-                 auxiliary_loss: bool = False):
+                 auxiliary_loss: bool = False,
+                 drop_path_prob: float = 0.):
         super().__init__(self.AMOEBA_OPS,
                          merge_op='loose_end',
@@ -729,7 +730,8 @@ class AmoebaNet(NDS):
                          width=width,
                          num_cells=num_cells,
                          dataset=dataset,
-                         auxiliary_loss=auxiliary_loss)
+                         auxiliary_loss=auxiliary_loss,
+                         drop_path_prob=drop_path_prob)

 @model_wrapper
@@ -757,14 +759,16 @@ class PNAS(NDS):
                  width: Union[Tuple[int, ...], int] = (16, 24, 32),
                  num_cells: Union[Tuple[int, ...], int] = (4, 8, 12, 16, 20),
                  dataset: Literal['cifar', 'imagenet'] = 'cifar',
-                 auxiliary_loss: bool = False):
+                 auxiliary_loss: bool = False,
+                 drop_path_prob: float = 0.):
         super().__init__(self.PNAS_OPS,
                          merge_op='all',
                          num_nodes_per_cell=5,
                          width=width,
                          num_cells=num_cells,
                          dataset=dataset,
-                         auxiliary_loss=auxiliary_loss)
+                         auxiliary_loss=auxiliary_loss,
+                         drop_path_prob=drop_path_prob)

 @model_wrapper
@@ -774,10 +778,16 @@ class DARTS(NDS):
     It is built upon :class:`~nni.retiarii.nn.pytorch.Cell`, and implemented based on :class:`~NDS`.
     Its operator candidates are :attribute:`~DARTS.DARTS_OPS`.
     It has 4 nodes per cell, and the output is concatenation of all nodes in the cell.
+
+    .. note::
+
+        ``none`` is not included in the operator candidates.
+        It has already been handled in the differentiable implementation of cell.
     """ + _INIT_PARAMETER_DOCS

     DARTS_OPS = [
-        'none',
+        # 'none',
         'max_pool_3x3',
         'avg_pool_3x3',
         'skip_connect',
@@ -791,14 +801,16 @@ class DARTS(NDS):
                  width: Union[Tuple[int, ...], int] = (16, 24, 32),
                  num_cells: Union[Tuple[int, ...], int] = (4, 8, 12, 16, 20),
                  dataset: Literal['cifar', 'imagenet'] = 'cifar',
-                 auxiliary_loss: bool = False):
+                 auxiliary_loss: bool = False,
+                 drop_path_prob: float = 0.):
         super().__init__(self.DARTS_OPS,
                          merge_op='all',
                          num_nodes_per_cell=4,
                          width=width,
                          num_cells=num_cells,
                          dataset=dataset,
-                         auxiliary_loss=auxiliary_loss)
+                         auxiliary_loss=auxiliary_loss,
+                         drop_path_prob=drop_path_prob)

     @classmethod
     def load_searched_model(
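With the constructor changes above, drop path can be switched on for any of these model spaces simply by passing drop_path_prob. A minimal usage sketch follows; the import path and the concrete hyper-parameter values are assumptions for illustration, not prescribed by this change.

# Hub import path assumed for NNI 2.x; adjust if your installation differs.
from nni.retiarii.hub.pytorch import DARTS

model = DARTS(
    width=16,            # illustrative value
    num_cells=8,         # illustrative value
    dataset='cifar',
    drop_path_prob=0.2,  # > 0 wraps every non-identity op with drop path
)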
@@ -224,29 +224,29 @@ class ShuffleNetSpace(nn.Module):
         for name, m in self.named_modules():
             if isinstance(m, nn.Conv2d):
                 if 'first' in name:
-                    torch.nn.init.normal_(m.weight, 0, 0.01)
+                    torch.nn.init.normal_(m.weight, 0, 0.01)  # type: ignore
                 else:
-                    torch.nn.init.normal_(m.weight, 0, 1.0 / m.weight.shape[1])
+                    torch.nn.init.normal_(m.weight, 0, 1.0 / m.weight.shape[1])  # type: ignore
                 if m.bias is not None:
-                    torch.nn.init.constant_(m.bias, 0)
+                    torch.nn.init.constant_(m.bias, 0)  # type: ignore
             elif isinstance(m, nn.BatchNorm2d):
                 if m.weight is not None:
-                    torch.nn.init.constant_(m.weight, 1)
+                    torch.nn.init.constant_(m.weight, 1)  # type: ignore
                 if m.bias is not None:
-                    torch.nn.init.constant_(m.bias, 0.0001)
+                    torch.nn.init.constant_(m.bias, 0.0001)  # type: ignore
                 if m.running_mean is not None:
-                    torch.nn.init.constant_(m.running_mean, 0)
+                    torch.nn.init.constant_(m.running_mean, 0)  # type: ignore
             elif isinstance(m, nn.BatchNorm1d):
                 if m.weight is not None:
-                    torch.nn.init.constant_(m.weight, 1)
+                    torch.nn.init.constant_(m.weight, 1)  # type: ignore
                 if m.bias is not None:
-                    torch.nn.init.constant_(m.bias, 0.0001)
+                    torch.nn.init.constant_(m.bias, 0.0001)  # type: ignore
                 if m.running_mean is not None:
-                    torch.nn.init.constant_(m.running_mean, 0)
+                    torch.nn.init.constant_(m.running_mean, 0)  # type: ignore
             elif isinstance(m, nn.Linear):
-                torch.nn.init.normal_(m.weight, 0, 0.01)
+                torch.nn.init.normal_(m.weight, 0, 0.01)  # type: ignore
                 if m.bias is not None:
-                    torch.nn.init.constant_(m.bias, 0)
+                    torch.nn.init.constant_(m.bias, 0)  # type: ignore

     @classmethod
     def fixed_arch(cls, arch: dict) -> FixedFactory: