Unverified Commit ac892fc7 authored by Yuge Zhang, committed by GitHub

Model space hub enhancements (v2.9) (#5050)

parent 802650ff
@@ -7,7 +7,6 @@ The implementation is based on NDS.
It's called ``nasnet.py`` simply because NASNet was the first to propose such a structure.
"""
from collections import OrderedDict
from functools import partial
from typing import Tuple, List, Union, Iterable, Dict, Callable, Optional, cast
@@ -235,20 +234,6 @@ class AuxiliaryHead(nn.Module):
return x
class SequentialBreakdown(nn.Sequential):
"""Return all layers of a sequential."""
def __init__(self, sequential: nn.Sequential):
super().__init__(OrderedDict(sequential.named_children()))
def forward(self, inputs):
result = []
for module in self:
inputs = module(inputs)
result.append(inputs)
return result
class CellPreprocessor(nn.Module):
"""
Aligning the shape of predecessors.
@@ -296,7 +281,8 @@ class CellBuilder:
C: nn.MaybeChoice[int],
num_nodes: int,
merge_op: Literal['all', 'loose_end'],
first_cell_reduce: bool, last_cell_reduce: bool):
first_cell_reduce: bool, last_cell_reduce: bool,
drop_path_prob: float):
self.C_prev_in = C_prev_in # This is the number of output channels of the cell before the last cell.
self.C_in = C_in # This is the number of output channels of the last cell.
self.C = C # This is NOT C_out of this stage; instead, C_out = C * len(cell.output_node_indices).
@@ -305,6 +291,7 @@ class CellBuilder:
self.merge_op: Literal['all', 'loose_end'] = merge_op
self.first_cell_reduce = first_cell_reduce
self.last_cell_reduce = last_cell_reduce
self.drop_path_prob = drop_path_prob
self._expect_idx = 0
# It takes an index, i.e., the cell's position within the repeat.
@@ -318,11 +305,16 @@ class CellBuilder:
op: str, channels: int, is_reduction_cell: bool):
if is_reduction_cell and (
input_index is None or input_index < self.num_predecessors
): # could be none when constructing search sapce
): # could be none when constructing search space
stride = 2
else:
stride = 1
return OPS[op](channels, stride, True)
operation = OPS[op](channels, stride, True)
if self.drop_path_prob > 0 and not isinstance(operation, nn.Identity):
# Omit drop-path when operation is skip connect.
# https://github.com/quark0/darts/blob/f276dd346a09ae3160f8e3aca5c7b193fda1da37/cnn/model.py#L54
return nn.Sequential(operation, DropPath_(self.drop_path_prob))
return operation
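For context, here is a minimal sketch of what a drop-path module along the lines of ``DropPath_`` typically does, following the DARTS reference implementation linked above; the actual ``DropPath_`` used by this code is defined elsewhere in NNI and may differ in detail.

import torch
import torch.nn as nn

class DropPathSketch(nn.Module):
    """Hypothetical stand-in for DropPath_: zeroes an entire path per sample during training."""
    def __init__(self, drop_prob: float = 0.):
        super().__init__()
        self.drop_prob = drop_prob

    def forward(self, x):
        if self.training and self.drop_prob > 0.:
            keep_prob = 1. - self.drop_prob
            # One Bernoulli draw per sample, broadcast over channel and spatial dims.
            mask = torch.zeros(x.size(0), 1, 1, 1, device=x.device).bernoulli_(keep_prob)
            # Rescale kept paths so the expected activation magnitude is unchanged.
            return x / keep_prob * mask
        return x

This mirrors the reference implementation, which (as noted in the comment above) skips drop path on identity / skip-connect operations.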
def __call__(self, repeat_idx: int):
if self._expect_idx != repeat_idx:
@@ -483,6 +475,8 @@ class NDS(nn.Module):
See :class:`~nni.retiarii.nn.pytorch.Cell`.
num_nodes_per_cell : int
See :class:`~nni.retiarii.nn.pytorch.Cell`.
drop_path_prob : float
Probability of applying drop path. Drop path is enabled when this is set to a value greater than 0.
"""
def __init__(self,
@@ -492,12 +486,14 @@ class NDS(nn.Module):
width: Union[Tuple[int, ...], int] = 16,
num_cells: Union[Tuple[int, ...], int] = 20,
dataset: Literal['cifar', 'imagenet'] = 'imagenet',
auxiliary_loss: bool = False):
auxiliary_loss: bool = False,
drop_path_prob: float = 0.):
super().__init__()
self.dataset = dataset
self.num_labels = 10 if dataset == 'cifar' else 1000
self.auxiliary_loss = auxiliary_loss
self.drop_path_prob = drop_path_prob
# preprocess the specified width and depth
if isinstance(width, Iterable):
@@ -546,7 +542,7 @@ class NDS(nn.Module):
# C_curr is number of channels for each operator in current stage.
# C_out is usually `C * num_nodes_per_cell` because of concat operator.
cell_builder = CellBuilder(op_candidates, C_pprev, C_prev, C_curr, num_nodes_per_cell,
merge_op, stage_idx > 0, last_cell_reduce)
merge_op, stage_idx > 0, last_cell_reduce, drop_path_prob)
stage: Union[NDSStage, nn.Sequential] = NDSStage(cell_builder, num_cells_per_stage[stage_idx])
if isinstance(stage, NDSStage):
@@ -581,7 +577,6 @@ class NDS(nn.Module):
if auxiliary_loss:
assert isinstance(self.stages[2], nn.Sequential), 'Auxiliary loss can only be enabled in retrain mode.'
self.stages[2] = SequentialBreakdown(cast(nn.Sequential, self.stages[2]))
self.auxiliary_head = AuxiliaryHead(C_to_auxiliary, self.num_labels, dataset=self.dataset) # type: ignore
self.global_pooling = nn.AdaptiveAvgPool2d((1, 1))
@@ -595,12 +590,13 @@ class NDS(nn.Module):
s0 = s1 = self.stem(inputs)
for stage_idx, stage in enumerate(self.stages):
if stage_idx == 2 and self.auxiliary_loss:
s = list(stage([s0, s1]).values())
s0, s1 = s[-1]
if self.training:
if stage_idx == 2 and self.auxiliary_loss and self.training:
assert isinstance(stage, nn.Sequential), 'Auxiliary loss is only supported for fixed architecture.'
for block_idx, block in enumerate(stage):
# auxiliary loss is attached to the first cell of the last stage.
logits_aux = self.auxiliary_head(s[0][1])
s0, s1 = block([s0, s1])
if block_idx == 0:
logits_aux = self.auxiliary_head(s1)
else:
s0, s1 = stage([s0, s1])
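For illustration, a hedged training-step sketch of how the two outputs are typically combined when auxiliary loss is enabled; ``model``, the batch tensors, and the 0.4 weight are assumptions for this sketch (the weight follows the usual DARTS convention) and are not fixed by this change.

import torch
import torch.nn as nn

criterion = nn.CrossEntropyLoss()
model.train()  # `model` is a hypothetical fixed-architecture NDS network built with auxiliary_loss=True
images, labels = torch.randn(8, 3, 32, 32), torch.randint(0, 10, (8,))
logits, logits_aux = model(images)  # assuming forward returns both outputs during training
loss = criterion(logits, labels) + 0.4 * criterion(logits_aux, labels)
loss.backward()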
@@ -655,14 +651,16 @@ class NASNet(NDS):
width: Union[Tuple[int, ...], int] = (16, 24, 32),
num_cells: Union[Tuple[int, ...], int] = (4, 8, 12, 16, 20),
dataset: Literal['cifar', 'imagenet'] = 'cifar',
auxiliary_loss: bool = False):
auxiliary_loss: bool = False,
drop_path_prob: float = 0.):
super().__init__(self.NASNET_OPS,
merge_op='loose_end',
num_nodes_per_cell=5,
width=width,
num_cells=num_cells,
dataset=dataset,
auxiliary_loss=auxiliary_loss)
auxiliary_loss=auxiliary_loss,
drop_path_prob=drop_path_prob)
@model_wrapper
@@ -686,14 +684,16 @@ class ENAS(NDS):
width: Union[Tuple[int, ...], int] = (16, 24, 32),
num_cells: Union[Tuple[int, ...], int] = (4, 8, 12, 16, 20),
dataset: Literal['cifar', 'imagenet'] = 'cifar',
auxiliary_loss: bool = False):
auxiliary_loss: bool = False,
drop_path_prob: float = 0.):
super().__init__(self.ENAS_OPS,
merge_op='loose_end',
num_nodes_per_cell=5,
width=width,
num_cells=num_cells,
dataset=dataset,
auxiliary_loss=auxiliary_loss)
auxiliary_loss=auxiliary_loss,
drop_path_prob=drop_path_prob)
@model_wrapper
@@ -721,7 +721,8 @@ class AmoebaNet(NDS):
width: Union[Tuple[int, ...], int] = (16, 24, 32),
num_cells: Union[Tuple[int, ...], int] = (4, 8, 12, 16, 20),
dataset: Literal['cifar', 'imagenet'] = 'cifar',
auxiliary_loss: bool = False):
auxiliary_loss: bool = False,
drop_path_prob: float = 0.):
super().__init__(self.AMOEBA_OPS,
merge_op='loose_end',
@@ -729,7 +730,8 @@ class AmoebaNet(NDS):
width=width,
num_cells=num_cells,
dataset=dataset,
auxiliary_loss=auxiliary_loss)
auxiliary_loss=auxiliary_loss,
drop_path_prob=drop_path_prob)
@model_wrapper
@@ -757,14 +759,16 @@ class PNAS(NDS):
width: Union[Tuple[int, ...], int] = (16, 24, 32),
num_cells: Union[Tuple[int, ...], int] = (4, 8, 12, 16, 20),
dataset: Literal['cifar', 'imagenet'] = 'cifar',
auxiliary_loss: bool = False):
auxiliary_loss: bool = False,
drop_path_prob: float = 0.):
super().__init__(self.PNAS_OPS,
merge_op='all',
num_nodes_per_cell=5,
width=width,
num_cells=num_cells,
dataset=dataset,
auxiliary_loss=auxiliary_loss)
auxiliary_loss=auxiliary_loss,
drop_path_prob=drop_path_prob)
@model_wrapper
@@ -774,10 +778,16 @@ class DARTS(NDS):
It is built upon :class:`~nni.retiarii.nn.pytorch.Cell`, and implemented based on :class:`~NDS`.
Its operator candidates are :attr:`~DARTS.DARTS_OPS`.
It has 4 nodes per cell, and the output is concatenation of all nodes in the cell.
.. note::
``none`` is not included in the operator candidates.
It has already been handled in the differentiable implementation of cell.
""" + _INIT_PARAMETER_DOCS
DARTS_OPS = [
'none',
# 'none',
'max_pool_3x3',
'avg_pool_3x3',
'skip_connect',
@@ -791,14 +801,16 @@ class DARTS(NDS):
width: Union[Tuple[int, ...], int] = (16, 24, 32),
num_cells: Union[Tuple[int, ...], int] = (4, 8, 12, 16, 20),
dataset: Literal['cifar', 'imagenet'] = 'cifar',
auxiliary_loss: bool = False):
auxiliary_loss: bool = False,
drop_path_prob: float = 0.):
super().__init__(self.DARTS_OPS,
merge_op='all',
num_nodes_per_cell=4,
width=width,
num_cells=num_cells,
dataset=dataset,
auxiliary_loss=auxiliary_loss)
auxiliary_loss=auxiliary_loss,
drop_path_prob=drop_path_prob)
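As a usage sketch, the new parameter can be passed directly when instantiating the space; the import path below is an assumption based on NNI's model space hub, and 0.2 is an illustrative probability rather than a recommended value.

from nni.retiarii.hub.pytorch import DARTS  # assumed import path

# Instantiate the DARTS search space with drop path enabled.
model_space = DARTS(width=(16, 24, 32),
                    num_cells=(4, 8, 12, 16, 20),
                    dataset='cifar',
                    drop_path_prob=0.2)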
@classmethod
def load_searched_model(
......
@@ -224,29 +224,29 @@ class ShuffleNetSpace(nn.Module):
for name, m in self.named_modules():
if isinstance(m, nn.Conv2d):
if 'first' in name:
torch.nn.init.normal_(m.weight, 0, 0.01)
torch.nn.init.normal_(m.weight, 0, 0.01) # type: ignore
else:
torch.nn.init.normal_(m.weight, 0, 1.0 / m.weight.shape[1])
torch.nn.init.normal_(m.weight, 0, 1.0 / m.weight.shape[1]) # type: ignore
if m.bias is not None:
torch.nn.init.constant_(m.bias, 0)
torch.nn.init.constant_(m.bias, 0) # type: ignore
elif isinstance(m, nn.BatchNorm2d):
if m.weight is not None:
torch.nn.init.constant_(m.weight, 1)
torch.nn.init.constant_(m.weight, 1) # type: ignore
if m.bias is not None:
torch.nn.init.constant_(m.bias, 0.0001)
torch.nn.init.constant_(m.bias, 0.0001) # type: ignore
if m.running_mean is not None:
torch.nn.init.constant_(m.running_mean, 0)
torch.nn.init.constant_(m.running_mean, 0) # type: ignore
elif isinstance(m, nn.BatchNorm1d):
if m.weight is not None:
torch.nn.init.constant_(m.weight, 1)
torch.nn.init.constant_(m.weight, 1) # type: ignore
if m.bias is not None:
torch.nn.init.constant_(m.bias, 0.0001)
torch.nn.init.constant_(m.bias, 0.0001) # type: ignore
if m.running_mean is not None:
torch.nn.init.constant_(m.running_mean, 0)
torch.nn.init.constant_(m.running_mean, 0) # type: ignore
elif isinstance(m, nn.Linear):
torch.nn.init.normal_(m.weight, 0, 0.01)
torch.nn.init.normal_(m.weight, 0, 0.01) # type: ignore
if m.bias is not None:
torch.nn.init.constant_(m.bias, 0)
torch.nn.init.constant_(m.bias, 0) # type: ignore
@classmethod
def fixed_arch(cls, arch: dict) -> FixedFactory:
......