Unverified commit 481aa292, authored by Maze, committed by GitHub

Fix Autoformer to be compatible with the RandomOneShot strategy (#4987)

parent 5a3d82e8
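A minimal usage sketch for context (not part of this commit): it assumes the `RandomOneShot` strategy in `nni.retiarii.strategy` accepts a `mutation_hooks` argument, like the other one-shot strategies exercised in the test changes below; `AutoformerSpace` and `get_extra_mutation_hooks` themselves come from this diff.

# Hypothetical usage sketch; strategy name and its `mutation_hooks` argument are assumptions.
import nni.retiarii.strategy as stg
from nni.retiarii.hub.pytorch.autoformer import AutoformerSpace

space = AutoformerSpace(qkv_bias=True, drop_rate=0.0, drop_path_rate=0.1)
# Register the hooks introduced in this commit so that ClsToken / AbsPosEmbed
# are replaced by their mixed counterparts when the supernet is built.
strategy = stg.RandomOneShot(mutation_hooks=AutoformerSpace.get_extra_mutation_hooks())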
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

-import itertools
-from typing import Optional, Tuple, cast
+from typing import Optional, Tuple, cast, Any, Dict

import torch
import torch.nn.functional as F
from timm.models.layers import trunc_normal_, DropPath

import nni.retiarii.nn.pytorch as nn
-from nni.retiarii import model_wrapper
+from nni.retiarii import model_wrapper, basic_unit
+from nni.retiarii.nn.pytorch.api import ValueChoiceX
+from nni.retiarii.oneshot.pytorch.supermodule.operation import MixedOperation
+from nni.retiarii.oneshot.pytorch.supermodule._valuechoice_utils import traverse_all_options
+from nni.retiarii.oneshot.pytorch.supermodule._operation_utils import Slicable as _S, MaybeWeighted as _W
+
+from .utils.fixed import FixedFactory
+from .utils.pretrained import load_pretrained_weight


class RelativePosition2D(nn.Module):
@@ -16,10 +23,8 @@ class RelativePosition2D(nn.Module):
        super().__init__()
        self.head_embed_dim = head_embed_dim
        self.legnth = length
-        self.embeddings_table_v = nn.Parameter(
-            torch.randn(length * 2 + 2, head_embed_dim))
-        self.embeddings_table_h = nn.Parameter(
-            torch.randn(length * 2 + 2, head_embed_dim))
+        self.embeddings_table_v = nn.Parameter(torch.randn(length * 2 + 2, head_embed_dim))
+        self.embeddings_table_h = nn.Parameter(torch.randn(length * 2 + 2, head_embed_dim))

        trunc_normal_(self.embeddings_table_v, std=.02)
        trunc_normal_(self.embeddings_table_h, std=.02)
@@ -28,48 +33,31 @@ class RelativePosition2D(nn.Module):
        # remove the first cls token distance computation
        length_q = length_q - 1
        length_k = length_k - 1
-        range_vec_q = torch.arange(length_q)
-        range_vec_k = torch.arange(length_k)
+        # init in the device directly, rather than move to device
+        range_vec_q = torch.arange(length_q, device=self.embeddings_table_v.device)
+        range_vec_k = torch.arange(length_k, device=self.embeddings_table_v.device)
        # compute the row and column distance
-        distance_mat_v = (range_vec_k[None, :] //
-                          int(length_q ** 0.5) -
-                          range_vec_q[:, None] //
-                          int(length_q ** 0.5))
-        distance_mat_h = (range_vec_k[None, :] %
-                          int(length_q ** 0.5) -
-                          range_vec_q[:, None] %
-                          int(length_q ** 0.5))
+        length_q_sqrt = int(length_q ** 0.5)
+        distance_mat_v = (range_vec_k[None, :] // length_q_sqrt - range_vec_q[:, None] // length_q_sqrt)
+        distance_mat_h = (range_vec_k[None, :] % length_q_sqrt - range_vec_q[:, None] % length_q_sqrt)
        # clip the distance to the range of [-legnth, legnth]
-        distance_mat_clipped_v = torch.clamp(
-            distance_mat_v, -self.legnth, self.legnth)
-        distance_mat_clipped_h = torch.clamp(
-            distance_mat_h, -self.legnth, self.legnth)
+        distance_mat_clipped_v = torch.clamp(distance_mat_v, - self.legnth, self.legnth)
+        distance_mat_clipped_h = torch.clamp(distance_mat_h, - self.legnth, self.legnth)
-        # translate the distance from [1, 2 * legnth + 1], 0 is for the cls
-        # token
+        # translate the distance from [1, 2 * legnth + 1], 0 is for the cls token
        final_mat_v = distance_mat_clipped_v + self.legnth + 1
        final_mat_h = distance_mat_clipped_h + self.legnth + 1
        # pad the 0 which represent the cls token
-        final_mat_v = F.pad(
-            final_mat_v, (1, 0, 1, 0), "constant", 0)
-        final_mat_h = F.pad(
-            final_mat_h, (1, 0, 1, 0), "constant", 0)
-
-        final_mat_v = torch.tensor(
-            final_mat_v,
-            dtype=torch.long,
-            device=self.embeddings_table_v.device)
-        final_mat_h = torch.tensor(
-            final_mat_h,
-            dtype=torch.long,
-            device=self.embeddings_table_v.device)
+        final_mat_v = F.pad(final_mat_v, (1, 0, 1, 0), "constant", 0)
+        final_mat_h = F.pad(final_mat_h, (1, 0, 1, 0), "constant", 0)
+
+        final_mat_v = final_mat_v.long()
+        final_mat_h = final_mat_h.long()
        # get the embeddings with the corresponding distance
-        embeddings = self.embeddings_table_v[final_mat_v] + \
-            self.embeddings_table_h[final_mat_h]
+        embeddings = self.embeddings_table_v[final_mat_v] + self.embeddings_table_h[final_mat_h]

        return embeddings
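As an illustration of the index arithmetic above (not part of the diff), the sketch below computes the row and column distances for a 2x2 patch grid with plain torch, following the same formula as the forward pass.

# Illustration only: relative-position distances for a 2x2 patch grid (length_q = length_k = 4).
import torch

length_q = length_k = 4        # token count after removing the cls token
sqrt = int(length_q ** 0.5)    # grid side = 2
rq = torch.arange(length_q)
rk = torch.arange(length_k)
dist_v = rk[None, :] // sqrt - rq[:, None] // sqrt   # row distance, values in [-1, 1]
dist_h = rk[None, :] % sqrt - rq[:, None] % sqrt     # column distance, values in [-1, 1]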


class RelativePositionAttention(nn.Module):
    """
    This class is designed to support the relative position in attention.
@@ -80,61 +68,62 @@ class RelativePositionAttention(nn.Module):
    and keys in self-attention modules.
    """

    def __init__(
-            self,
-            embed_dim,
-            fixed_embed_dim,
-            num_heads,
-            attn_drop=0.,
-            proj_drop=0,
-            rpe=False,
-            qkv_bias=False,
-            qk_scale=None,
-            rpe_length=14) -> None:
+            self, embed_dim, num_heads,
+            attn_drop=0., proj_drop=0.,
+            qkv_bias=False, qk_scale=None,
+            rpe_length=14, rpe=False,
+            head_dim=64):
        super().__init__()
        self.num_heads = num_heads
-        head_dim = embed_dim // num_heads
+        # head_dim is fixed 64 in official autoformer. set head_dim = None to use flex head dim.
+        self.head_dim = head_dim or (embed_dim // num_heads)
        self.scale = qk_scale or head_dim ** -0.5
-        self.qkv = nn.Linear(embed_dim, embed_dim * 3, bias=qkv_bias)
+        # Please refer to MixedMultiheadAttention for details.
+        self.q = nn.Linear(embed_dim, head_dim * num_heads, bias = qkv_bias)
+        self.k = nn.Linear(embed_dim, head_dim * num_heads, bias = qkv_bias)
+        self.v = nn.Linear(embed_dim, head_dim * num_heads, bias = qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
-        self.proj = nn.Linear(embed_dim, embed_dim)
+        self.proj = nn.Linear(head_dim * num_heads, embed_dim)
        self.proj_drop = nn.Dropout(proj_drop)
        self.rpe = rpe
        if rpe:
-            self.rel_pos_embed_k = RelativePosition2D(
-                fixed_embed_dim // num_heads, rpe_length)
-            self.rel_pos_embed_v = RelativePosition2D(
-                fixed_embed_dim // num_heads, rpe_length)
+            self.rel_pos_embed_k = RelativePosition2D(head_dim, rpe_length)
+            self.rel_pos_embed_v = RelativePosition2D(head_dim, rpe_length)
    def forward(self, x):
-        B, N, C = x.shape
-        qkv = self.qkv(x).reshape(B,N,3,self.num_heads,C//self.num_heads).permute(
-            2,0,3,1,4)
-        # make torchscript happy (cannot use tensor as tuple)
-        q, k, v = qkv[0], qkv[1], qkv[2]
+        B, N, _ = x.shape
+        head_dim = self.head_dim
+        # num_heads can not get from self.num_heads directly,
+        # use -1 to compute implicitly.
+        num_heads = -1
+        q = self.q(x).reshape(B, N, num_heads, head_dim).permute(0, 2, 1, 3)
+        k = self.k(x).reshape(B, N, num_heads, head_dim).permute(0, 2, 1, 3)
+        v = self.v(x).reshape(B, N, num_heads, head_dim).permute(0, 2, 1, 3)
+        num_heads = q.size(1)
        attn = (q @ k.transpose(-2, -1)) * self.scale
        if self.rpe:
            r_p_k = self.rel_pos_embed_k(N, N)
            attn = attn + (
-                q.permute(2, 0, 1, 3).reshape(
-                    N, self.num_heads * B, -1) @ r_p_k.transpose(
-                    2, 1)) .transpose(1, 0).reshape(
-                B, self.num_heads, N, N) * self.scale
+                q.permute(2, 0, 1, 3).reshape(N, num_heads * B, head_dim) @ r_p_k.transpose(2, 1)
+            ).transpose(1, 0).reshape(B, num_heads, N, N) * self.scale
        attn = attn.softmax(dim=-1)
        attn = self.attn_drop(attn)
-        x = (attn @ v).transpose(1, 2).reshape(B, N, C)
+        x = (attn @ v).transpose(1, 2).reshape(B, N, num_heads * head_dim)
        if self.rpe:
+            attn_1 = attn.permute(2, 0, 1, 3).reshape(N, B * num_heads, N)
            r_p_v = self.rel_pos_embed_v(N, N)
-            attn_1 = attn.permute(
-                2, 0, 1, 3).reshape(
-                N, B * self.num_heads, -1)
            # The size of attention is (B, num_heads, N, N), reshape it to (N, B*num_heads, N) and do batch matmul with
            # the relative position embedding of V (N, N, head_dim) get shape like (N, B*num_heads, head_dim). We reshape it to the
            # same size as x (B, num_heads, N, hidden_dim)
-            x = x + (attn_1 @ r_p_v).transpose(1, 0).reshape(B,
-                self.num_heads, N, -1).transpose(2, 1).reshape(B, N, -1)
+            x = x + (attn_1 @ r_p_v).transpose(1, 0).reshape(B, num_heads, N, head_dim).transpose(2, 1).reshape(B, N, num_heads * head_dim)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x
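A quick shape sketch (illustration only, not part of the diff) of why the forward pass above reshapes with -1 instead of reading self.num_heads: under a one-shot strategy the q/k/v projections may be sliced to fewer heads than the supernet maximum, and the runtime head count falls out of the channel count.

# With head_dim = 64 and a sampled head count of 3, the sliced projection yields 192 channels,
# so reshaping with -1 recovers num_heads = 3 without consulting self.num_heads.
import torch

B, N, head_dim, sampled_heads = 2, 197, 64, 3
q_proj = torch.randn(B, N, sampled_heads * head_dim)          # output of the sliced self.q
q = q_proj.reshape(B, N, -1, head_dim).permute(0, 2, 1, 3)    # (B, num_heads, N, head_dim)
assert q.size(1) == sampled_heads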


@@ -146,61 +135,60 @@ class TransformerEncoderLayer(nn.Module):
    The pytorch build-in nn.TransformerEncoderLayer() does not support customed attention.
    """

    def __init__(
-            self,
-            embed_dim,
-            fixed_embed_dim,
-            num_heads,
-            mlp_ratio=4.,
-            qkv_bias=False,
-            qk_scale=None,
-            rpe=False,
-            drop_rate=0.,
-            attn_drop=0.,
-            proj_drop=0.,
-            drop_path=0.,
-            pre_norm=True,
-            rpe_length=14,
+            self, embed_dim, num_heads, mlp_ratio=4.,
+            qkv_bias=False, qk_scale=None, rpe=False,
+            drop_rate=0., attn_drop=0., proj_drop=0., drop_path=0.,
+            pre_norm=True, rpe_length=14, head_dim=64
    ):
        super().__init__()

        self.normalize_before = pre_norm
-        self.drop_path = DropPath(
-            drop_path) if drop_path > 0. else nn.Identity()
+        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
        self.dropout = drop_rate
        self.attn = RelativePositionAttention(
            embed_dim=embed_dim,
-            fixed_embed_dim=fixed_embed_dim,
            num_heads=num_heads,
            attn_drop=attn_drop,
            proj_drop=proj_drop,
            rpe=rpe,
            qkv_bias=qkv_bias,
            qk_scale=qk_scale,
-            rpe_length=rpe_length)
+            rpe_length=rpe_length,
+            head_dim=head_dim
+        )

        self.attn_layer_norm = nn.LayerNorm(embed_dim)
        self.ffn_layer_norm = nn.LayerNorm(embed_dim)

        self.activation_fn = nn.GELU()

        self.fc1 = nn.Linear(
-            cast(int, embed_dim), cast(int, nn.ValueChoice.to_int(
-                embed_dim * mlp_ratio)))
+            cast(int, embed_dim),
+            cast(int, nn.ValueChoice.to_int(embed_dim * mlp_ratio))
+        )
        self.fc2 = nn.Linear(
-            cast(int, nn.ValueChoice.to_int(
-                embed_dim * mlp_ratio)),
-            cast(int, embed_dim))
+            cast(int, nn.ValueChoice.to_int(embed_dim * mlp_ratio)),
+            cast(int, embed_dim)
+        )

+    def maybe_layer_norm(self, layer_norm, x, before=False, after=False):
+        assert before ^ after
+        if after ^ self.normalize_before:
+            return layer_norm(x)
+        else:
+            return x
    def forward(self, x):
        """
        Args:
            x (Tensor): input to the layer of shape `(batch, patch_num , sample_embed_dim)`
        Returns:
            encoded output of shape `(batch, patch_num, sample_embed_dim)`
        """
        residual = x
        x = self.maybe_layer_norm(self.attn_layer_norm, x, before=True)
        x = self.attn(x)
-        x = nn.functional.dropout(x, p=self.dropout, training=self.training)
+        x = F.dropout(x, p=self.dropout, training=self.training)
        x = self.drop_path(x)
        x = residual + x
        x = self.maybe_layer_norm(self.attn_layer_norm, x, after=True)
@@ -209,20 +197,86 @@ class TransformerEncoderLayer(nn.Module):
        x = self.maybe_layer_norm(self.ffn_layer_norm, x, before=True)
        x = self.fc1(x)
        x = self.activation_fn(x)
-        x = nn.functional.dropout(x, p=self.dropout, training=self.training)
+        x = F.dropout(x, p=self.dropout, training=self.training)
        x = self.fc2(x)
-        x = nn.functional.dropout(x, p=self.dropout, training=self.training)
+        x = F.dropout(x, p=self.dropout, training=self.training)
        x = self.drop_path(x)
        x = residual + x
        x = self.maybe_layer_norm(self.ffn_layer_norm, x, after=True)
        return x
-    def maybe_layer_norm(self, layer_norm, x, before=False, after=False):
-        assert before ^ after
-        if after ^ self.normalize_before:
-            return layer_norm(x)
-        else:
-            return x


+@basic_unit
+class ClsToken(nn.Module):
+    """ Concat class token with dim=embed_dim before patch embedding.
+    """
+
+    def __init__(self, embed_dim: int):
+        super().__init__()
+        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
+        trunc_normal_(self.cls_token, std=.02)
+
+    def forward(self, x):
+        return torch.cat((self.cls_token.expand(x.shape[0], -1, -1), x), dim=1)
+
+
+class MixedClsToken(MixedOperation, ClsToken):
+    """ Mixed class token concat operation.
+
+    Supported arguments are:
+
+    - ``embed_dim``
+
+    Prefix of cls_token will be sliced.
+    """
+    bound_type = ClsToken
+    argument_list = ['embed_dim']
+
+    def super_init_argument(self, name: str, value_choice: ValueChoiceX):
+        return max(traverse_all_options(value_choice))
+
+    def forward_with_args(self, embed_dim,
+                          inputs: torch.Tensor) -> torch.Tensor:
+        embed_dim_ = _W(embed_dim)
+
+        cls_token = _S(self.cls_token)[..., :embed_dim_]
+
+        return torch.cat((cls_token.expand(inputs.shape[0], -1, -1), inputs), dim=1)
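To make the prefix-slicing rule concrete, here is a small illustration in plain torch (not part of the diff): the supernet stores the cls_token at the largest embed_dim, and a sampled smaller width concatenates only its leading channels.

# Illustration of MixedClsToken's prefix-slicing semantics.
import torch

max_embed_dim, sampled_embed_dim = 240, 192
cls_token = torch.zeros(1, 1, max_embed_dim)          # super weight (largest width)
patches = torch.randn(2, 196, sampled_embed_dim)      # patch embeddings at the sampled width
out = torch.cat((cls_token[..., :sampled_embed_dim].expand(2, -1, -1), patches), dim=1)
assert out.shape == (2, 197, sampled_embed_dim)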

+@basic_unit
+class AbsPosEmbed(nn.Module):
+    """ Add absolute position embedding on patch embedding.
+    """
+
+    def __init__(self, length: int, embed_dim: int):
+        super().__init__()
+        self.pos_embed = nn.Parameter(torch.zeros(1, length, embed_dim))
+        trunc_normal_(self.pos_embed, std=.02)
+
+    def forward(self, x):
+        return x + self.pos_embed
+
+
+class MixedAbsPosEmbed(MixedOperation, AbsPosEmbed):
+    """ Mixed absolute position embedding add operation.
+
+    Supported arguments are:
+
+    - ``embed_dim``
+
+    Prefix of pos_embed will be sliced.
+    """
+    bound_type = AbsPosEmbed
+    argument_list = ['embed_dim']
+
+    def super_init_argument(self, name: str, value_choice: ValueChoiceX):
+        return max(traverse_all_options(value_choice))
+
+    def forward_with_args(self, embed_dim,
+                          inputs: torch.Tensor) -> torch.Tensor:
+        embed_dim_ = _W(embed_dim)
+
+        pos_embed = _S(self.pos_embed)[..., :embed_dim_]
+
+        return inputs + pos_embed


@model_wrapper
@@ -267,87 +321,144 @@ class AutoformerSpace(nn.Module):
        The scaler on score map in self-attention.
    rpe : bool
        Whether to use relative position encoding.
    """

    def __init__(
        self,
        search_embed_dim: Tuple[int, ...] = (192, 216, 240),
-        search_mlp_ratio: Tuple[float, ...] = (3.5, 4.0),
+        search_mlp_ratio: Tuple[float, ...] = (3.0, 3.5, 4.0),
        search_num_heads: Tuple[int, ...] = (3, 4),
        search_depth: Tuple[int, ...] = (12, 13, 14),
        img_size: int = 224,
        patch_size: int = 16,
        in_chans: int = 3,
        num_classes: int = 1000,
        qkv_bias: bool = False,
        drop_rate: float = 0.,
        attn_drop_rate: float = 0.,
        drop_path_rate: float = 0.,
        pre_norm: bool = True,
        global_pool: bool = False,
        abs_pos: bool = True,
        qk_scale: Optional[float] = None,
        rpe: bool = True,
    ):
        super().__init__()

+        # define search space parameters
        embed_dim = nn.ValueChoice(list(search_embed_dim), label="embed_dim")
-        fixed_embed_dim = nn.ModelParameterChoice(
-            list(search_embed_dim), label="embed_dim")
        depth = nn.ValueChoice(list(search_depth), label="depth")
+        mlp_ratios = [nn.ValueChoice(list(search_mlp_ratio), label=f"mlp_ratio_{i}") for i in range(max(search_depth))]
+        num_heads = [nn.ValueChoice(list(search_num_heads), label=f"num_head_{i}") for i in range(max(search_depth))]

        self.patch_embed = nn.Conv2d(
-            in_chans,
-            cast(int, embed_dim),
-            kernel_size=patch_size,
-            stride=patch_size)
+            in_chans, cast(int, embed_dim),
+            kernel_size = patch_size,
+            stride = patch_size
+        )
        self.patches_num = int((img_size // patch_size) ** 2)
        self.global_pool = global_pool

-        self.cls_token = nn.Parameter(torch.zeros(1, 1, cast(int, fixed_embed_dim)))
-        trunc_normal_(self.cls_token, std=.02)
-
-        dpr = [
-            x.item() for x in torch.linspace(
-                0,
-                drop_path_rate,
-                max(search_depth))]  # stochastic depth decay rule
-
-        self.abs_pos = abs_pos
-        if self.abs_pos:
-            self.pos_embed = nn.Parameter(torch.zeros(
-                1, self.patches_num + 1, cast(int, fixed_embed_dim)))
-            trunc_normal_(self.pos_embed, std=.02)
-
-        self.blocks = nn.Repeat(lambda index: nn.LayerChoice([
-            TransformerEncoderLayer(embed_dim=embed_dim,
-                                    fixed_embed_dim=fixed_embed_dim,
-                                    num_heads=num_heads, mlp_ratio=mlp_ratio,
-                                    qkv_bias=qkv_bias, drop_rate=drop_rate,
-                                    attn_drop=attn_drop_rate,
-                                    drop_path=dpr[index],
-                                    rpe_length=img_size // patch_size,
-                                    qk_scale=qk_scale, rpe=rpe,
-                                    pre_norm=pre_norm,)
-            for mlp_ratio, num_heads in itertools.product(search_mlp_ratio, search_num_heads)
-        ], label=f'layer{index}'), depth)
-
-        self.pre_norm = pre_norm
-        if self.pre_norm:
-            self.norm = nn.LayerNorm(cast(int, embed_dim))
-        self.head = nn.Linear(
-            cast(int, embed_dim),
-            num_classes) if num_classes > 0 else nn.Identity()
+        self.cls_token = ClsToken(cast(int, embed_dim))
+        self.pos_embed = AbsPosEmbed(self.patches_num+1, cast(int, embed_dim)) if abs_pos else nn.Identity()
+
+        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, max(search_depth))]  # stochastic depth decay rule
+
+        self.blocks = nn.Repeat(
+            lambda index: TransformerEncoderLayer(
+                embed_dim = embed_dim, num_heads = num_heads[index], mlp_ratio=mlp_ratios[index],
+                qkv_bias = qkv_bias, drop_rate = drop_rate, attn_drop = attn_drop_rate, drop_path=dpr[index],
+                rpe_length=img_size // patch_size, qk_scale=qk_scale, rpe=rpe, pre_norm=pre_norm, head_dim = 64
+            ), depth
+        )
+
+        self.norm = nn.LayerNorm(cast(int, embed_dim)) if pre_norm else nn.Identity()
+        self.head = nn.Linear(cast(int, embed_dim), num_classes) if num_classes > 0 else nn.Identity()
+
+    @classmethod
+    def get_extra_mutation_hooks(cls):
+        return [MixedAbsPosEmbed.mutate, MixedClsToken.mutate]
+
+    @classmethod
+    def load_searched_model(
+        cls, name: str,
+        pretrained: bool = False, download: bool = False, progress: bool = True
+    ) -> nn.Module:
+
+        init_kwargs = {'qkv_bias': True, 'drop_rate': 0.0, 'drop_path_rate': 0.1, 'global_pool': True, 'num_classes': 1000}
+
+        if name == 'autoformer-tiny':
+            mlp_ratio = [3.5, 3.5, 3.0, 3.5, 3.0, 3.0, 4.0, 4.0, 3.5, 4.0, 3.5, 4.0, 3.5] + [3.0]
+            num_head = [3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 4, 3, 3] + [3]
+            arch: Dict[str, Any] = {
+                'embed_dim': 192,
+                'depth': 13
+            }
+            for i in range(14):
+                arch[f'mlp_ratio_{i}'] = mlp_ratio[i]
+                arch[f'num_head_{i}'] = num_head[i]
+            init_kwargs.update({
+                'search_embed_dim': (240, 216, 192),
+                'search_mlp_ratio': (4.0, 3.5, 3.0),
+                'search_num_heads': (4, 3),
+                'search_depth': (14, 13, 12),
+            })
+        elif name == 'autoformer-small':
+            mlp_ratio = [3.0, 3.5, 3.0, 3.5, 4.0, 4.0, 4.0, 4.0, 4.0, 4.0, 4.0, 3.5, 4.0] + [3.0]
+            num_head = [6, 6, 5, 7, 5, 5, 5, 6, 6, 7, 7, 6, 7] + [5]
+            arch: Dict[str, Any] = {
+                'embed_dim': 384,
+                'depth': 13
+            }
+            for i in range(14):
+                arch[f'mlp_ratio_{i}'] = mlp_ratio[i]
+                arch[f'num_head_{i}'] = num_head[i]
+            init_kwargs.update({
+                'search_embed_dim': (448, 384, 320),
+                'search_mlp_ratio': (4.0, 3.5, 3.0),
+                'search_num_heads': (7, 6, 5),
+                'search_depth': (14, 13, 12),
+            })
+        elif name == 'autoformer-base':
+            mlp_ratio = [3.5, 3.5, 4.0, 3.5, 4.0, 3.5, 3.5, 3.0, 4.0, 4.0, 3.0, 4.0, 3.0, 3.5] + [3.0, 3.0]
+            num_head = [9, 9, 9, 9, 9, 10, 9, 9, 10, 9, 10, 9, 9, 10] + [8, 8]
+            arch: Dict[str, Any] = {
+                'embed_dim': 576,
+                'depth': 14
+            }
+            for i in range(16):
+                arch[f'mlp_ratio_{i}'] = mlp_ratio[i]
+                arch[f'num_head_{i}'] = num_head[i]
+            init_kwargs.update({
+                'search_embed_dim': (624, 576, 528),
+                'search_mlp_ratio': (4.0, 3.5, 3.0),
+                'search_num_heads': (10, 9, 8),
+                'search_depth': (16, 15, 14),
+            })
+        else:
+            raise ValueError(f'Unsupported architecture with name: {name}')
+
+        model_factory = FixedFactory(cls, arch)
+        model = model_factory(**init_kwargs)
+
+        if pretrained:
+            weight_file = load_pretrained_weight(name, download=download, progress=progress)
+            pretrained_weights = torch.load(weight_file)
+            model.load_state_dict(pretrained_weights)
+
+        return model
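A short usage sketch for the classmethod above; the call mirrors its signature, and pretrained loading relies on the weight URLs registered in the pretrained-weight table further down this diff.

# Load a fixed, searched Autoformer architecture (weights downloaded on demand).
import torch
from nni.retiarii.hub.pytorch.autoformer import AutoformerSpace

model = AutoformerSpace.load_searched_model('autoformer-tiny', pretrained=True, download=True)
model.eval()
logits = model(torch.randn(1, 3, 224, 224))   # expected shape (1, 1000) given num_classes=1000 above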
    def forward(self, x):
        B = x.shape[0]
        x = self.patch_embed(x)
        x = x.permute(0, 2, 3, 1).view(B, self.patches_num, -1)
-        cls_tokens = self.cls_token.expand(B, -1, -1)
-        x = torch.cat((cls_tokens, x), dim=1)
-        if self.abs_pos:
-            x = x + self.pos_embed
+        x = self.cls_token(x)
+        x = self.pos_embed(x)
        x = self.blocks(x)
-        if self.pre_norm:
-            x = self.norm(x)
+        x = self.norm(x)
        if self.global_pool:
            x = torch.mean(x[:, 1:], dim=1)
        else:
...
@@ -37,6 +37,11 @@ PRETRAINED_WEIGHT_URLS = {
    # spos
    'spos': f'{NNI_BLOB}/nashub/spos-0b17f6fc.pth',
+
+    # autoformer
+    'autoformer-tiny': f'{NNI_BLOB}/nashub/autoformer-searched-tiny-1e90ebc1.pth',
+    'autoformer-small': f'{NNI_BLOB}/nashub/autoformer-searched-small-4bc5d4e5.pth',
+    'autoformer-base': f'{NNI_BLOB}/nashub/autoformer-searched-base-c417590a.pth'
}
...
@@ -140,7 +140,7 @@ class Slicable(Generic[T]):
            raise TypeError(f'Unsuppoted weight type: {type(weight)}')
        self.weight = weight

-    def __getitem__(self, index: slice_type | multidim_slice) -> T:
+    def __getitem__(self, index: slice_type | multidim_slice | Any) -> T:
        if not isinstance(index, tuple):
            index = (index, )
        index = cast(multidim_slice, index)
@@ -267,7 +267,7 @@ def _iterate_over_slice_type(s: slice_type):
def _iterate_over_multidim_slice(ms: multidim_slice):
    """Get :class:`MaybeWeighted` instances in ``ms``."""
    for s in ms:
-        if s is not None:
+        if s is not None and s is not Ellipsis:
            yield from _iterate_over_slice_type(s)
@@ -286,8 +286,8 @@ def _evaluate_multidim_slice(ms: multidim_slice, value_fn: _value_fn_type = None
    """Wraps :meth:`MaybeWeighted.evaluate` to evaluate the whole ``multidim_slice``."""
    res = []
    for s in ms:
-        if s is not None:
+        if s is not None and s is not Ellipsis:
            res.append(_evaluate_slice_type(s, value_fn))
        else:
-            res.append(None)
+            res.append(s)
    return tuple(res)
@@ -35,6 +35,7 @@ __all__ = [
    'MixedLinear',
    'MixedConv2d',
    'MixedBatchNorm2d',
+    'MixedLayerNorm',
    'MixedMultiHeadAttention',
    'NATIVE_MIXED_OPERATIONS',
]
@@ -472,6 +473,74 @@ class MixedBatchNorm2d(MixedOperation, nn.BatchNorm2d):
            eps,
        )


+class MixedLayerNorm(MixedOperation, nn.LayerNorm):
+    """
+    Mixed LayerNorm operation.
+
+    Supported arguments are:
+
+    - ``normalized_shape``
+    - ``eps`` (only supported in path sampling)
+
+    For path-sampling, prefix of ``weight`` and ``bias`` are sliced.
+    For weighted cases, the maximum ``normalized_shape`` is used directly.
+
+    eps is required to be float.
+    """
+
+    bound_type = retiarii_nn.LayerNorm
+    argument_list = ['normalized_shape', 'eps']
+
+    @staticmethod
+    def _to_tuple(value: scalar_or_scalar_dict[Any]) -> tuple[Any, Any]:
+        if not isinstance(value, tuple):
+            return (value, value)
+        return value
+
+    def super_init_argument(self, name: str, value_choice: ValueChoiceX):
+        if name not in ['normalized_shape']:
+            raise NotImplementedError(f'Unsupported value choice on argument: {name}')
+        all_sizes = set(traverse_all_options(value_choice))
+        if any(isinstance(sz, (tuple, list)) for sz in all_sizes):
+            # transpose
+            all_sizes = list(zip(*all_sizes))
+            # maximum dim should be calculated on every dimension
+            return (max(self._to_tuple(sz)) for sz in all_sizes)
+        else:
+            return max(all_sizes)
+
+    def forward_with_args(self,
+                          normalized_shape,
+                          eps: float,
+                          inputs: torch.Tensor) -> torch.Tensor:
+        if any(isinstance(arg, dict) for arg in [eps]):
+            raise ValueError(_diff_not_compatible_error.format('eps', 'LayerNorm'))
+
+        if isinstance(normalized_shape, dict):
+            normalized_shape = self.normalized_shape
+
+        # make it as tuple
+        if isinstance(normalized_shape, int):
+            normalized_shape = (normalized_shape, )
+        if isinstance(self.normalized_shape, int):
+            normalized_shape = (self.normalized_shape, )
+
+        # slice all the normalized shape
+        indices = [slice(0, min(i, j)) for i, j in zip(normalized_shape, self.normalized_shape)]
+
+        # remove _S(*)
+        weight = self.weight[indices] if self.weight is not None else None
+        bias = self.bias[indices] if self.bias is not None else None
+
+        return F.layer_norm(
+            inputs,
+            normalized_shape,
+            weight,
+            bias,
+            eps
+        )
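An illustration of the prefix rule used above (not part of the diff): the super LayerNorm keeps the maximum normalized_shape, and a sampled smaller shape normalizes with the leading entries of weight and bias.

# Plain torch sketch of the slicing behaviour, with a max shape of 64 and a sampled shape of 32.
import torch
import torch.nn.functional as F

super_weight, super_bias = torch.ones(64), torch.zeros(64)
sampled_shape = (32,)
prefix = [slice(0, min(i, j)) for i, j in zip(sampled_shape, (64,))]
weight, bias = super_weight[prefix[0]], super_bias[prefix[0]]
out = F.layer_norm(torch.randn(2, 16, 32), sampled_shape, weight, bias, 1e-5)
assert out.shape == (2, 16, 32)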


class MixedMultiHeadAttention(MixedOperation, nn.MultiheadAttention):
    """
@@ -628,6 +697,7 @@ NATIVE_MIXED_OPERATIONS: list[Type[MixedOperation]] = [
    MixedLinear,
    MixedConv2d,
    MixedBatchNorm2d,
+    MixedLayerNorm,
    MixedMultiHeadAttention,
]
...
@@ -3,7 +3,7 @@ import pytest
import numpy as np
import torch
import torch.nn as nn
-from nni.retiarii.nn.pytorch import ValueChoice, Conv2d, BatchNorm2d, Linear, MultiheadAttention
+from nni.retiarii.nn.pytorch import ValueChoice, Conv2d, BatchNorm2d, LayerNorm, Linear, MultiheadAttention
from nni.retiarii.oneshot.pytorch.base_lightning import traverse_and_mutate_submodules
from nni.retiarii.oneshot.pytorch.supermodule.differentiable import (
    MixedOpDifferentiablePolicy, DifferentiableMixedLayer, DifferentiableMixedInput, GumbelSoftmax,
@@ -28,6 +28,12 @@ def test_slice():
    assert S(weight)[:, 1:W(3)*2+1, :, 9:13].shape == (3, 6, 24, 4)
    assert S(weight)[:, 1:W(3)*2+1].shape == (3, 6, 24, 23)

+    # Ellipsis
+    assert S(weight)[..., 9:13].shape == (3, 7, 24, 4)
+    assert S(weight)[:2, ..., 1:W(3)+1].shape == (2, 7, 24, 3)
+    assert S(weight)[..., 1:W(3)*2+1].shape == (3, 7, 24, 6)
+    assert S(weight)[..., :10, 1:W(3)*2+1].shape == (3, 7, 10, 6)
+
    # no effect
    assert S(weight)[:] is weight
@@ -227,6 +233,23 @@ def test_mixed_batchnorm2d():
    _mixed_operation_differentiable_sanity_check(bn, torch.randn(2, 64, 3, 3))


+def test_mixed_layernorm():
+    ln = LayerNorm(ValueChoice([32, 64], label='normalized_shape'), elementwise_affine=True)
+
+    assert _mixed_operation_sampling_sanity_check(ln, {'normalized_shape': 32}, torch.randn(2, 16, 32)).size(-1) == 32
+    assert _mixed_operation_sampling_sanity_check(ln, {'normalized_shape': 64}, torch.randn(2, 16, 64)).size(-1) == 64
+
+    _mixed_operation_differentiable_sanity_check(ln, torch.randn(2, 16, 64))
+
+    import itertools
+    ln = LayerNorm(ValueChoice(list(itertools.product([16, 32, 64], [8, 16])), label='normalized_shape'))
+
+    assert list(_mixed_operation_sampling_sanity_check(ln, {'normalized_shape': (16, 8)}, torch.randn(2, 16, 8)).shape[-2:]) == [16, 8]
+    assert list(_mixed_operation_sampling_sanity_check(ln, {'normalized_shape': (64, 16)}, torch.randn(2, 64, 16)).shape[-2:]) == [64, 16]
+
+    _mixed_operation_differentiable_sanity_check(ln, torch.randn(2, 64, 16))
+
+
def test_mixed_mhattn():
    mhattn = MultiheadAttention(ValueChoice([4, 8], label='emb'), 4)
...
@@ -78,6 +78,11 @@ def _strategy_factory(alias, space_type):
            extra_mutation_hooks.append(NDSStagePathSampling.mutate)
        else:
            extra_mutation_hooks.append(NDSStageDifferentiable.mutate)

+    # Autoformer search space require specific extra hooks
+    if space_type == 'autoformer':
+        from nni.retiarii.hub.pytorch.autoformer import MixedAbsPosEmbed, MixedClsToken
+        extra_mutation_hooks.extend([MixedAbsPosEmbed.mutate, MixedClsToken.mutate])
+
    if alias == 'darts':
        return stg.DARTS(mutation_hooks=extra_mutation_hooks)
@@ -149,7 +154,7 @@ def _dataset_factory(dataset_type, subset=20):
    'mobilenetv3_small',
    'proxylessnas',
    'shufflenet',
-    # 'autoformer',
+    'autoformer',
    'nasnet',
    'enas',
    'amoeba',
@@ -186,7 +191,7 @@ def test_hub_oneshot(space_type, strategy_type):
    NDS_SPACES = ['amoeba', 'darts', 'pnas', 'enas', 'nasnet']
    if strategy_type == 'proxyless':
        if 'width' in space_type or 'depth' in space_type or \
-                any(space_type.startswith(prefix) for prefix in NDS_SPACES + ['proxylessnas', 'mobilenetv3']):
+                any(space_type.startswith(prefix) for prefix in NDS_SPACES + ['proxylessnas', 'mobilenetv3', 'autoformer']):
            pytest.skip('The space has used unsupported APIs.')
    if strategy_type in ['darts', 'gumbel'] and space_type == 'mobilenetv3':
        pytest.skip('Skip as it consumes too much memory.')
...