"""RepViT-style image classifier with extra GAM, SimAM and EMA attention blocks.

The network is described by a list of config rows ``[k, t, c, SE, HS, s]``:
kernel size, expansion ratio, output channels, use squeeze-excite (0/1),
use_hs flag (0/1) and stride.  The ``repvit_m*`` factory functions at the
bottom of the file hold the concrete tables and are registered with timm.
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
from timm.models import register_model
from timm.models.layers import SqueezeExcite
from timm.models.vision_transformer import trunc_normal_


def _make_divisible(v, divisor, min_value=None):
    """Round ``v`` to the nearest multiple of ``divisor``.

    This function is taken from the original tf repo:
    https://github.com/tensorflow/models/blob/master/research/slim/nets/mobilenet/mobilenet.py

    :param v: value (here: a channel count) to round.
    :param divisor: step the value is rounded to.
    :param min_value: lower bound for the result; defaults to ``divisor``.
    :return: the rounded value, bumped up by one ``divisor`` when plain
        rounding would fall below 90% of ``v``.
    """
    if min_value is None:
        min_value = divisor
    new_v = max(min_value, int(v + divisor / 2) // divisor * divisor)
    # Make sure that round down does not go down by more than 10%.
    if new_v < 0.9 * v:
        new_v += divisor
    return new_v


class GAM_Attention(nn.Module):
    """Channel gate followed by a spatial gate, both applied multiplicatively.

    :param in_channels: number of channels of the input feature map.
    :param rate: reduction ratio of the hidden width in both branches.
    """

    def __init__(self, in_channels, rate=4):
        super(GAM_Attention, self).__init__()
        hidden = int(in_channels / rate)

        # Per-position MLP over the channel dimension.
        self.channel_attention = nn.Sequential(
            nn.Linear(in_channels, hidden),
            nn.GELU(),
            nn.Linear(hidden, in_channels)
        )

        # Two 7x7 convolutions; padding=3 keeps the spatial size unchanged.
        self.spatial_attention = nn.Sequential(
            nn.Conv2d(in_channels, hidden, kernel_size=7, padding=3),
            nn.BatchNorm2d(hidden),
            nn.GELU(),
            nn.Conv2d(hidden, in_channels, kernel_size=7, padding=3),
            nn.BatchNorm2d(in_channels)
        )

    def forward(self, x):
        """Return ``x`` gated by channel attention, then by spatial attention.

        :param x: tensor of shape ``(b, c, h, w)``.
        :return: tensor with the same shape as ``x``.
        """
        b, c, h, w = x.shape
        # reshape (not view): the permuted tensor is not guaranteed to have
        # view-compatible strides, e.g. when ``x`` is a slice of a larger
        # tensor.  When a view is possible reshape returns exactly that view.
        x_permute = x.permute(0, 2, 3, 1).reshape(b, -1, c)
        x_att_permute = self.channel_attention(x_permute).reshape(b, h, w, c)
        x_channel_att = x_att_permute.permute(0, 3, 1, 2).sigmoid()

        x = x * x_channel_att

        x_spatial_att = self.spatial_attention(x).sigmoid()
        out = x * x_spatial_att

        return out


class Simam(torch.nn.Module):
    """Attention gate without learnable parameters.

    Each activation is re-weighted by a sigmoid of its squared deviation from
    the per-channel spatial mean, normalised by the channel variance.

    :param channels: accepted for signature compatibility; not used.
    :param e_lambda: constant added to the variance term in the denominator.
    """

    def __init__(self, channels=None, e_lambda=1e-4):
        super(Simam, self).__init__()

        # Attribute name (including its spelling) is kept as-is so existing
        # references to it keep working.
        self.activaton = nn.Sigmoid()
        self.e_lambda = e_lambda

    def __repr__(self):
        s = self.__class__.__name__ + '('
        s += ('lambda=%f)' % self.e_lambda)
        return s

    @staticmethod
    def get_module_name():
        return "simam"

    def forward(self, x):
        """Return ``x`` scaled by its per-element attention weights.

        :param x: tensor of shape ``(b, c, h, w)``.
        :return: tensor with the same shape as ``x``.
        """
        b, c, h, w = x.size()

        # For a 1x1 feature map ``w * h - 1`` is 0, which made the variance
        # term 0/0 = NaN.  Clamping to 1 leaves every larger input unchanged.
        n = max(w * h - 1, 1)

        x_minus_mu_square = (x - x.mean(dim=[2, 3], keepdim=True)).pow(2)
        variance = x_minus_mu_square.sum(dim=[2, 3], keepdim=True) / n
        y = x_minus_mu_square / (4 * (variance + self.e_lambda)) + 0.5

        return x * self.activaton(y)


class EMA(nn.Module):
    """Grouped attention that mixes a 1x1 (pooled) and a 3x3 branch.

    The channels are split into ``factor`` groups which are folded into the
    batch dimension; each group is re-weighted by a single-channel map.

    :param channels: number of input channels.
    :param factor: number of groups.  ``forward`` reshapes the input to
        ``(b * factor, -1, h, w)``, so ``channels`` needs to be divisible by
        ``factor`` for the layer widths below to match.
    """

    def __init__(self, channels, factor=8):
        super(EMA, self).__init__()
        self.groups = factor
        assert channels // self.groups > 0
        group_channels = channels // self.groups
        self.softmax = nn.Softmax(-1)
        self.agp = nn.AdaptiveAvgPool2d((1, 1))
        self.pool_h = nn.AdaptiveAvgPool2d((None, 1))
        self.pool_w = nn.AdaptiveAvgPool2d((1, None))
        self.gn = nn.GroupNorm(group_channels, group_channels)
        self.conv1x1 = nn.Conv2d(group_channels, group_channels, kernel_size=1, stride=1, padding=0)
        self.conv3x3 = nn.Conv2d(group_channels, group_channels, kernel_size=3, stride=1, padding=1)

    def forward(self, x):
        """Return ``x`` re-weighted group-wise; output shape equals input shape.

        :param x: tensor of shape ``(b, c, h, w)``.
        """
        b, c, h, w = x.size()
        group_x = x.reshape(b * self.groups, -1, h, w)  # b*g,c//g,h,w
        # Directional pooling: one descriptor per row and one per column,
        # processed jointly by the 1x1 conv and then split apart again.
        x_h = self.pool_h(group_x)
        x_w = self.pool_w(group_x).permute(0, 1, 3, 2)
        hw = self.conv1x1(torch.cat([x_h, x_w], dim=2))
        x_h, x_w = torch.split(hw, [h, w], dim=2)
        x1 = self.gn(group_x * x_h.sigmoid() * x_w.permute(0, 1, 3, 2).sigmoid())
        x2 = self.conv3x3(group_x)
        # Cross-branch weighting: the pooled, softmaxed descriptor of one
        # branch is multiplied with the flattened feature map of the other.
        x11 = self.softmax(self.agp(x1).reshape(b * self.groups, -1, 1).permute(0, 2, 1))
        x12 = x2.reshape(b * self.groups, c // self.groups, -1)  # b*g, c//g, hw
        x21 = self.softmax(self.agp(x2).reshape(b * self.groups, -1, 1).permute(0, 2, 1))
        x22 = x1.reshape(b * self.groups, c // self.groups, -1)  # b*g, c//g, hw
        weights = (torch.matmul(x11, x12) + torch.matmul(x21, x22)).reshape(b * self.groups, 1, h, w)
        return (group_x * weights.sigmoid()).reshape(b, c, h, w)


class Conv2d_BN(torch.nn.Sequential):
    """Bias-free ``Conv2d`` (module ``c``) followed by ``BatchNorm2d`` (``bn``).

    :param a: input channels.
    :param b: output channels.
    :param ks: kernel size.
    :param stride: convolution stride.
    :param pad: convolution padding.
    :param dilation: convolution dilation.
    :param groups: convolution groups.
    :param bn_weight_init: constant the BatchNorm weight is initialised to.
    :param resolution: accepted for signature compatibility; not used.
    """

    def __init__(self, a, b, ks=1, stride=1, pad=0, dilation=1,
                 groups=1, bn_weight_init=1, resolution=-10000):
        super().__init__()
        self.add_module('c', torch.nn.Conv2d(
            a, b, ks, stride, pad, dilation, groups, bias=False))
        self.add_module('bn', torch.nn.BatchNorm2d(b))
        torch.nn.init.constant_(self.bn.weight, bn_weight_init)
        torch.nn.init.constant_(self.bn.bias, 0)

    @torch.no_grad()
    def fuse(self):
        """Return a single ``Conv2d`` with the BatchNorm statistics folded in."""
        c, bn = self._modules.values()
        w = bn.weight / (bn.running_var + bn.eps)**0.5
        w = c.weight * w[:, None, None, None]
        b = bn.bias - bn.running_mean * bn.weight / \
            (bn.running_var + bn.eps)**0.5
        m = torch.nn.Conv2d(
            w.size(1) * self.c.groups, w.size(0), w.shape[2:],
            stride=self.c.stride, padding=self.c.padding,
            dilation=self.c.dilation, groups=self.c.groups,
            device=c.weight.device)
        m.weight.data.copy_(w)
        m.bias.data.copy_(b)
        return m


class Residual(torch.nn.Module):
    """Compute ``x + m(x)``, optionally dropping the branch per sample.

    :param m: wrapped module.
    :param drop: probability of zeroing the branch for a sample while
        training; surviving samples are rescaled by ``1 / (1 - drop)``.
    """

    def __init__(self, m, drop=0.):
        super().__init__()
        self.m = m
        self.drop = drop

    def forward(self, x):
        if self.training and self.drop > 0:
            return x + self.m(x) * torch.rand(x.size(0), 1, 1, 1,
                                              device=x.device).ge_(self.drop).div(1 - self.drop).detach()
        else:
            return x + self.m(x)

    @torch.no_grad()
    def fuse(self):
        """Fold the skip connection into the wrapped convolution when possible.

        Returns the fused ``Conv2d`` for a wrapped ``Conv2d_BN`` or ``Conv2d``
        and ``self`` for any other wrapped module.  The identity is a 1x1
        kernel zero-padded by one on each side, i.e. a 3x3 kernel is assumed.
        """
        if isinstance(self.m, Conv2d_BN):
            m = self.m.fuse()
            assert(m.groups == m.in_channels)
            identity = torch.ones(m.weight.shape[0], m.weight.shape[1], 1, 1)
            identity = torch.nn.functional.pad(identity, [1, 1, 1, 1])
            m.weight += identity.to(m.weight.device)
            return m
        elif isinstance(self.m, torch.nn.Conv2d):
            m = self.m
            assert(m.groups != m.in_channels)
            identity = torch.ones(m.weight.shape[0], m.weight.shape[1], 1, 1)
            identity = torch.nn.functional.pad(identity, [1, 1, 1, 1])
            m.weight += identity.to(m.weight.device)
            return m
        else:
            return self


class RepVGGDW(torch.nn.Module):
    """Depthwise 3x3 conv-BN + depthwise 1x1 conv + identity, then BatchNorm.

    :param ed: number of channels (input and output).
    """

    def __init__(self, ed) -> None:
        super().__init__()
        self.conv = Conv2d_BN(ed, ed, 3, 1, 1, groups=ed)
        self.conv1 = torch.nn.Conv2d(ed, ed, 1, 1, 0, groups=ed)
        self.dim = ed
        self.bn = torch.nn.BatchNorm2d(ed)

    def forward(self, x):
        return self.bn((self.conv(x) + self.conv1(x)) + x)

    @torch.no_grad()
    def fuse(self):
        """Merge all three branches and the trailing BatchNorm into one conv."""
        conv = self.conv.fuse()
        conv1 = self.conv1

        conv_w = conv.weight
        conv_b = conv.bias
        conv1_w = conv1.weight
        conv1_b = conv1.bias

        # Bring the 1x1 kernel and the identity to 3x3 by zero padding.
        conv1_w = torch.nn.functional.pad(conv1_w, [1, 1, 1, 1])

        identity = torch.nn.functional.pad(
            torch.ones(conv1_w.shape[0], conv1_w.shape[1], 1, 1, device=conv1_w.device),
            [1, 1, 1, 1])

        final_conv_w = conv_w + conv1_w + identity
        final_conv_b = conv_b + conv1_b

        conv.weight.data.copy_(final_conv_w)
        conv.bias.data.copy_(final_conv_b)

        # Fold the trailing BatchNorm into the merged convolution.
        bn = self.bn
        w = bn.weight / (bn.running_var + bn.eps)**0.5
        w = conv.weight * w[:, None, None, None]
        b = bn.bias + (conv.bias - bn.running_mean) * bn.weight / \
            (bn.running_var + bn.eps)**0.5
        conv.weight.data.copy_(w)
        conv.bias.data.copy_(b)
        return conv


class RepViTBlock(nn.Module):
    """Token mixer (depthwise, optional SE) followed by a residual channel mixer.

    :param inp: input channels.
    :param hidden_dim: expanded width; must equal ``2 * inp``.
    :param oup: output channels.
    :param kernel_size: kernel of the strided depthwise conv (stride 2 only).
    :param stride: 1 or 2.  Stride 1 requires ``inp == oup``.
    :param use_se: insert ``SqueezeExcite(inp, 0.25)`` when truthy.
    :param use_hs: accepted for config compatibility; the activation is GELU
        regardless of its value.
    """

    def __init__(self, inp, hidden_dim, oup, kernel_size, stride, use_se, use_hs):
        super(RepViTBlock, self).__init__()
        assert stride in [1, 2]

        self.identity = stride == 1 and inp == oup
        assert(hidden_dim == 2 * inp)

        if stride == 2:
            self.token_mixer = nn.Sequential(
                Conv2d_BN(inp, inp, kernel_size, stride, (kernel_size - 1) // 2, groups=inp),
                RepVGGDW(inp),
                SqueezeExcite(inp, 0.25) if use_se else nn.Identity(),
                Conv2d_BN(inp, oup, ks=1, stride=1, pad=0)
            )
            self.channel_mixer = Residual(nn.Sequential(
                # pw
                Conv2d_BN(oup, 2 * oup, 1, 1, 0),
                nn.GELU(),
                # pw-linear
                Conv2d_BN(2 * oup, oup, 1, 1, 0, bn_weight_init=0),
            ))
        else:
            assert(self.identity)
            self.token_mixer = nn.Sequential(
                RepVGGDW(inp),
                SqueezeExcite(inp, 0.25) if use_se else nn.Identity(),
            )
            self.channel_mixer = Residual(nn.Sequential(
                # pw
                Conv2d_BN(inp, hidden_dim, 1, 1, 0),
                nn.GELU(),
                # pw-linear
                Conv2d_BN(hidden_dim, oup, 1, 1, 0, bn_weight_init=0),
            ))

    def forward(self, x):
        return self.channel_mixer(self.token_mixer(x))


class BN_Linear(torch.nn.Sequential):
    """``BatchNorm1d`` (module ``bn``) followed by ``Linear`` (module ``l``).

    :param a: input features.
    :param b: output features.
    :param bias: whether the linear layer has a bias.
    :param std: standard deviation of the truncated-normal weight init.
    """

    def __init__(self, a, b, bias=True, std=0.02):
        super().__init__()
        self.add_module('bn', torch.nn.BatchNorm1d(a))
        self.add_module('l', torch.nn.Linear(a, b, bias=bias))
        trunc_normal_(self.l.weight, std=std)
        if bias:
            torch.nn.init.constant_(self.l.bias, 0)

    @torch.no_grad()
    def fuse(self):
        """Return a single ``Linear`` with the BatchNorm statistics folded in."""
        bn, linear = self._modules.values()
        w = bn.weight / (bn.running_var + bn.eps)**0.5
        b = bn.bias - self.bn.running_mean * \
            self.bn.weight / (bn.running_var + bn.eps)**0.5
        w = linear.weight * w[None, :]
        if linear.bias is None:
            b = b @ self.l.weight.T
        else:
            b = (linear.weight @ b[:, None]).view(-1) + self.l.bias
        m = torch.nn.Linear(w.size(1), w.size(0), device=linear.weight.device)
        m.weight.data.copy_(w)
        m.bias.data.copy_(b)
        return m


class Classfier(nn.Module):
    """Classification head with an optional second (distillation) head.

    :param dim: number of input features.
    :param num_classes: number of classes; ``<= 0`` replaces the heads by
        ``Identity``.
    :param distillation: create and use the second head.
    """

    def __init__(self, dim, num_classes, distillation=True):
        super().__init__()
        self.classifier = BN_Linear(dim, num_classes) if num_classes > 0 else torch.nn.Identity()
        self.distillation = distillation
        if distillation:
            self.classifier_dist = BN_Linear(dim, num_classes) if num_classes > 0 else torch.nn.Identity()

    def forward(self, x):
        """Return the head output(s).

        With distillation the result is a tuple of both head outputs while
        training and their average otherwise.
        """
        if self.distillation:
            x = self.classifier(x), self.classifier_dist(x)
            if not self.training:
                x = (x[0] + x[1]) / 2
        else:
            x = self.classifier(x)
        return x

    @torch.no_grad()
    def fuse(self):
        """Return one fused ``Linear``; with distillation, the mean of both heads."""
        classifier = self.classifier.fuse()
        if self.distillation:
            classifier_dist = self.classifier_dist.fuse()
            classifier.weight += classifier_dist.weight
            classifier.bias += classifier_dist.bias
            classifier.weight /= 2
            classifier.bias /= 2
            return classifier
        else:
            return classifier


class RepViT(nn.Module):
    """Patch embedding, GAM + SimAM attention, RepViT blocks, EMA, classifier.

    :param cfgs: list of rows ``[k, t, c, use_se, use_hs, s]`` (kernel size,
        expansion ratio, output channels, SE flag, HS flag, stride).
    :param num_classes: number of output classes.
    :param distillation: build the classifier with a distillation head.
    """

    def __init__(self, cfgs, num_classes=1000, distillation=False):
        super(RepViT, self).__init__()
        # setting of inverted residual blocks
        self.cfgs = cfgs

        # building first layer: two stride-2 3x3 conv-BN stages
        input_channel = self.cfgs[0][2]
        patch_embed = torch.nn.Sequential(
            Conv2d_BN(3, input_channel // 2, 3, 2, 1),
            torch.nn.GELU(),
            Conv2d_BN(input_channel // 2, input_channel, 3, 2, 1))
        layers = [patch_embed]
        layers.append(GAM_Attention(in_channels=input_channel))
        layers.append(Simam())

        # building inverted residual blocks
        block = RepViTBlock
        for k, t, c, use_se, use_hs, s in self.cfgs:
            output_channel = _make_divisible(c, 8)
            exp_size = _make_divisible(input_channel * t, 8)
            layers.append(block(input_channel, exp_size, output_channel, k, s, use_se, use_hs))
            input_channel = output_channel

        self.features = nn.ModuleList(layers)
        self.ema = EMA(output_channel)
        self.classifier = Classfier(output_channel, num_classes, distillation)

    def forward(self, x):
        """Run the backbone, pool globally and classify.

        :param x: image batch; the first layer expects 3 input channels.
        :return: whatever the classifier head returns for the pooled features.
        """
        for layer in self.features:
            x = layer(x)
        x = self.ema(x)

        # Global average pool -> (batch, channels).
        x = torch.nn.functional.adaptive_avg_pool2d(x, 1).flatten(1)
        x = self.classifier(x)
        return x

    def _initialize_weights(self):
        """Kaiming-init conv weights, normal-init linear weights, zero biases.

        Not called by ``__init__``.
        """
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                nn.init.normal_(m.weight, 0, 0.01)
                # Same guard as for convolutions: bias-free layers have no
                # bias tensor to initialise.
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)


@register_model
def repvit_m0_6(pretrained=False, num_classes=1000, distillation=False):
    """Build the ``repvit_m0_6`` configuration of :class:`RepViT`.

    ``pretrained`` is accepted but not used by this function.
    """
    cfgs = [
        # k, t, c, SE, HS, s
        [3, 2, 40, 1, 0, 1],
        [3, 2, 40, 0, 0, 1],
        [3, 2, 80, 0, 0, 2],
        [3, 2, 80, 1, 0, 1],
        [3, 2, 80, 0, 0, 1],
        [3, 2, 160, 0, 1, 2],
        [3, 2, 160, 1, 1, 1], [3, 2, 160, 0, 1, 1],
        [3, 2, 160, 1, 1, 1], [3, 2, 160, 0, 1, 1],
        [3, 2, 160, 1, 1, 1], [3, 2, 160, 0, 1, 1],
        [3, 2, 160, 1, 1, 1], [3, 2, 160, 0, 1, 1],
        [3, 2, 160, 0, 1, 1],
        [3, 2, 320, 0, 1, 2],
        [3, 2, 320, 1, 1, 1],
    ]
    return RepViT(cfgs, num_classes=num_classes, distillation=distillation)


@register_model
def repvit_m0_9(pretrained=False, num_classes=1000, distillation=False):
    """Build the ``repvit_m0_9`` configuration of :class:`RepViT`.

    ``pretrained`` is accepted but not used by this function.
    """
    cfgs = [
        # k, t, c, SE, HS, s
        [3, 2, 48, 1, 0, 1],
        [3, 2, 48, 0, 0, 1],
        [3, 2, 48, 0, 0, 1],
        [3, 2, 96, 0, 0, 2],
        [3, 2, 96, 1, 0, 1],
        [3, 2, 96, 0, 0, 1],
        [3, 2, 96, 0, 0, 1],
        [3, 2, 192, 0, 1, 2],
        [3, 2, 192, 1, 1, 1], [3, 2, 192, 0, 1, 1],
        [3, 2, 192, 1, 1, 1], [3, 2, 192, 0, 1, 1],
        [3, 2, 192, 1, 1, 1], [3, 2, 192, 0, 1, 1],
        [3, 2, 192, 1, 1, 1], [3, 2, 192, 0, 1, 1],
        [3, 2, 192, 1, 1, 1], [3, 2, 192, 0, 1, 1],
        [3, 2, 192, 1, 1, 1], [3, 2, 192, 0, 1, 1],
        [3, 2, 192, 1, 1, 1], [3, 2, 192, 0, 1, 1],
        [3, 2, 192, 0, 1, 1],
        [3, 2, 384, 0, 1, 2],
        [3, 2, 384, 1, 1, 1],
        [3, 2, 384, 0, 1, 1]
    ]
    return RepViT(cfgs, num_classes=num_classes, distillation=distillation)


# NOTE: a narrower repvit_m0_9 with the same row layout but stage widths
# 48 / 64 / 128 / 256 (instead of 48 / 96 / 192 / 384) was tried; reducing
# the channel widths alone cost close to 9 points of accuracy, so it is not
# registered.


@register_model
def repvit_m1_0(pretrained=False, num_classes=1000, distillation=False):
    """Build the ``repvit_m1_0`` configuration of :class:`RepViT`.

    ``pretrained`` is accepted but not used by this function.
    """
    cfgs = [
        # k, t, c, SE, HS, s
        [3, 2, 56, 1, 0, 1],
        [3, 2, 56, 0, 0, 1],
        [3, 2, 56, 0, 0, 1],
        [3, 2, 112, 0, 0, 2],
        [3, 2, 112, 1, 0, 1],
        [3, 2, 112, 0, 0, 1],
        [3, 2, 112, 0, 0, 1],
        [3, 2, 224, 0, 1, 2],
        [3, 2, 224, 1, 1, 1], [3, 2, 224, 0, 1, 1],
        [3, 2, 224, 1, 1, 1], [3, 2, 224, 0, 1, 1],
        [3, 2, 224, 1, 1, 1], [3, 2, 224, 0, 1, 1],
        [3, 2, 224, 1, 1, 1], [3, 2, 224, 0, 1, 1],
        [3, 2, 224, 1, 1, 1], [3, 2, 224, 0, 1, 1],
        [3, 2, 224, 1, 1, 1], [3, 2, 224, 0, 1, 1],
        [3, 2, 224, 1, 1, 1], [3, 2, 224, 0, 1, 1],
        [3, 2, 224, 0, 1, 1],
        [3, 2, 448, 0, 1, 2],
        [3, 2, 448, 1, 1, 1],
        [3, 2, 448, 0, 1, 1]
    ]
    return RepViT(cfgs, num_classes=num_classes, distillation=distillation)


@register_model
def repvit_m1_1(pretrained=False, num_classes=1000, distillation=False):
    """Build the ``repvit_m1_1`` configuration of :class:`RepViT`.

    ``pretrained`` is accepted but not used by this function.
    """
    cfgs = [
        # k, t, c, SE, HS, s
        [3, 2, 64, 1, 0, 1],
        [3, 2, 64, 0, 0, 1],
        [3, 2, 64, 0, 0, 1],
        [3, 2, 128, 0, 0, 2],
        [3, 2, 128, 1, 0, 1],
        [3, 2, 128, 0, 0, 1],
        [3, 2, 128, 0, 0, 1],
        [3, 2, 256, 0, 1, 2],
        [3, 2, 256, 1, 1, 1], [3, 2, 256, 0, 1, 1],
        [3, 2, 256, 1, 1, 1], [3, 2, 256, 0, 1, 1],
        [3, 2, 256, 1, 1, 1], [3, 2, 256, 0, 1, 1],
        [3, 2, 256, 1, 1, 1], [3, 2, 256, 0, 1, 1],
        [3, 2, 256, 1, 1, 1], [3, 2, 256, 0, 1, 1],
        [3, 2, 256, 1, 1, 1], [3, 2, 256, 0, 1, 1],
        [3, 2, 256, 0, 1, 1],
        [3, 2, 512, 0, 1, 2],
        [3, 2, 512, 1, 1, 1],
        [3, 2, 512, 0, 1, 1]
    ]
    return RepViT(cfgs, num_classes=num_classes, distillation=distillation)


@register_model
def repvit_m1_5(pretrained=False, num_classes=1000, distillation=False):
    """Build the ``repvit_m1_5`` configuration of :class:`RepViT`.

    ``pretrained`` is accepted but not used by this function.
    """
    cfgs = [
        # k, t, c, SE, HS, s
        [3, 2, 64, 1, 0, 1], [3, 2, 64, 0, 0, 1],
        [3, 2, 64, 1, 0, 1], [3, 2, 64, 0, 0, 1],
        [3, 2, 64, 0, 0, 1],
        [3, 2, 128, 0, 0, 2],
        [3, 2, 128, 1, 0, 1], [3, 2, 128, 0, 0, 1],
        [3, 2, 128, 1, 0, 1], [3, 2, 128, 0, 0, 1],
        [3, 2, 128, 0, 0, 1],
        [3, 2, 256, 0, 1, 2],
        [3, 2, 256, 1, 1, 1], [3, 2, 256, 0, 1, 1],
        [3, 2, 256, 1, 1, 1], [3, 2, 256, 0, 1, 1],
        [3, 2, 256, 1, 1, 1], [3, 2, 256, 0, 1, 1],
        [3, 2, 256, 1, 1, 1], [3, 2, 256, 0, 1, 1],
        [3, 2, 256, 1, 1, 1], [3, 2, 256, 0, 1, 1],
        [3, 2, 256, 1, 1, 1], [3, 2, 256, 0, 1, 1],
        [3, 2, 256, 1, 1, 1], [3, 2, 256, 0, 1, 1],
        [3, 2, 256, 1, 1, 1], [3, 2, 256, 0, 1, 1],
        [3, 2, 256, 1, 1, 1], [3, 2, 256, 0, 1, 1],
        [3, 2, 256, 1, 1, 1], [3, 2, 256, 0, 1, 1],
        [3, 2, 256, 1, 1, 1], [3, 2, 256, 0, 1, 1],
        [3, 2, 256, 1, 1, 1], [3, 2, 256, 0, 1, 1],
        [3, 2, 256, 0, 1, 1],
        [3, 2, 512, 0, 1, 2],
        [3, 2, 512, 1, 1, 1], [3, 2, 512, 0, 1, 1],
        [3, 2, 512, 1, 1, 1], [3, 2, 512, 0, 1, 1]
    ]
    return RepViT(cfgs, num_classes=num_classes, distillation=distillation)


@register_model
def repvit_m2_3(pretrained=False, num_classes=1000, distillation=False):
    """Build the ``repvit_m2_3`` configuration of :class:`RepViT`.

    ``pretrained`` is accepted but not used by this function.
    """
    cfgs = [
        # k, t, c, SE, HS, s
        [3, 2, 80, 1, 0, 1], [3, 2, 80, 0, 0, 1],
        [3, 2, 80, 1, 0, 1], [3, 2, 80, 0, 0, 1],
        [3, 2, 80, 1, 0, 1], [3, 2, 80, 0, 0, 1],
        [3, 2, 80, 0, 0, 1],
        [3, 2, 160, 0, 0, 2],
        [3, 2, 160, 1, 0, 1], [3, 2, 160, 0, 0, 1],
        [3, 2, 160, 1, 0, 1], [3, 2, 160, 0, 0, 1],
        [3, 2, 160, 1, 0, 1], [3, 2, 160, 0, 0, 1],
        [3, 2, 160, 0, 0, 1],
        [3, 2, 320, 0, 1, 2],
        [3, 2, 320, 1, 1, 1], [3, 2, 320, 0, 1, 1],
        [3, 2, 320, 1, 1, 1], [3, 2, 320, 0, 1, 1],
        [3, 2, 320, 1, 1, 1], [3, 2, 320, 0, 1, 1],
        [3, 2, 320, 1, 1, 1], [3, 2, 320, 0, 1, 1],
        [3, 2, 320, 1, 1, 1], [3, 2, 320, 0, 1, 1],
        [3, 2, 320, 1, 1, 1], [3, 2, 320, 0, 1, 1],
        [3, 2, 320, 1, 1, 1], [3, 2, 320, 0, 1, 1],
        [3, 2, 320, 1, 1, 1], [3, 2, 320, 0, 1, 1],
        [3, 2, 320, 1, 1, 1], [3, 2, 320, 0, 1, 1],
        [3, 2, 320, 1, 1, 1], [3, 2, 320, 0, 1, 1],
        [3, 2, 320, 1, 1, 1], [3, 2, 320, 0, 1, 1],
        [3, 2, 320, 1, 1, 1], [3, 2, 320, 0, 1, 1],
        [3, 2, 320, 1, 1, 1], [3, 2, 320, 0, 1, 1],
        [3, 2, 320, 1, 1, 1], [3, 2, 320, 0, 1, 1],
        [3, 2, 320, 1, 1, 1], [3, 2, 320, 0, 1, 1],
        [3, 2, 320, 1, 1, 1], [3, 2, 320, 0, 1, 1],
        [3, 2, 320, 1, 1, 1], [3, 2, 320, 0, 1, 1],
        # [3, 2, 320, 1, 1, 1],
        # [3, 2, 320, 0, 1, 1],
        [3, 2, 320, 0, 1, 1],
        [3, 2, 640, 0, 1, 2],
        [3, 2, 640, 1, 1, 1],
        [3, 2, 640, 0, 1, 1],
        # [3, 2, 640, 1, 1, 1],
        # [3, 2, 640, 0, 1, 1]
    ]
    return RepViT(cfgs, num_classes=num_classes, distillation=distillation)