import inspect

import torch
import torch.nn as nn

from .video_stems import get_default_stem
from ._utils import Conv3DNoTemporal


# Number of residual blocks in each of the four stages, keyed by model depth.
BLOCK_CONFIG = {
    10: (1, 1, 1, 1),
    16: (2, 2, 2, 1),
    18: (2, 2, 2, 2),
    26: (2, 3, 4, 3),
    34: (3, 4, 6, 3),
    50: (3, 4, 6, 3),
    101: (3, 4, 23, 3),
    152: (3, 8, 36, 3)
}


class BasicBlock(nn.Module):
    """Two-convolution residual block for video (3D) ResNets.

    Args:
        inplanes (int): number of input channels.
        planes (int): number of output channels of both convolutions.
        conv_builder (callable): called as
            ``conv_builder(in_planes, out_planes, midplanes, stride)`` (stride
            omitted for the second conv) and must return an ``nn.Module``.
        stride (int, optional): stride forwarded to the first convolution.
            Defaults to 1.
        downsample (nn.Module, optional): applied to the input to produce the
            residual when given; otherwise the input itself is the residual.
            Defaults to None.
    """

    # Output channels = planes * expansion.
    expansion = 1

    def __init__(self, inplanes, planes, conv_builder, stride=1, downsample=None):
        # Intermediate channel count handed to conv_builder.
        # NOTE(review): presumably sized so a factorized (2+1)D convolution
        # matches the parameter count of a full 3x3x3 one -- confirm against
        # the conv builders in use.
        midplanes = (inplanes * planes * 3 * 3 * 3) // (inplanes * 3 * 3 + 3 * planes)

        super(BasicBlock, self).__init__()
        self.conv1 = nn.Sequential(
            conv_builder(inplanes, planes, midplanes, stride),
            nn.BatchNorm3d(planes),
            nn.ReLU(inplace=True)
        )
        # No ReLU here: the activation is applied after the residual addition.
        self.conv2 = nn.Sequential(
            conv_builder(planes, planes, midplanes),
            nn.BatchNorm3d(planes)
        )
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        """Return ``relu(conv2(conv1(x)) + residual)``."""
        residual = x

        out = self.conv1(x)
        out = self.conv2(out)
        if self.downsample is not None:
            residual = self.downsample(x)

        out += residual
        out = self.relu(out)

        return out


class Bottleneck(nn.Module):
    """Three-convolution (1x1x1 -> builder conv -> 1x1x1) residual block.

    Args:
        inplanes (int): number of input channels.
        planes (int): channel width of the first two convolutions; the block
            outputs ``planes * expansion`` channels.
        conv_builder (callable): called as
            ``conv_builder(planes, planes, midplanes, stride)`` and must return
            an ``nn.Module``.
        stride (int, optional): stride forwarded to the middle convolution.
            Defaults to 1.
        downsample (nn.Module, optional): applied to the input to produce the
            residual when given; otherwise the input itself is the residual.
            Defaults to None.
    """

    # Output channels = planes * expansion.
    expansion = 4

    def __init__(self, inplanes, planes, conv_builder, stride=1, downsample=None):
        super(Bottleneck, self).__init__()
        # Intermediate channel count handed to conv_builder (same formula as
        # BasicBlock).
        midplanes = (inplanes * planes * 3 * 3 * 3) // (inplanes * 3 * 3 + 3 * planes)

        # 1x1x1
        self.conv1 = nn.Sequential(
            nn.Conv3d(inplanes, planes, kernel_size=1, bias=False),
            nn.BatchNorm3d(planes),
            nn.ReLU(inplace=True)
        )
        # Second kernel
        self.conv2 = nn.Sequential(
            conv_builder(planes, planes, midplanes, stride),
            nn.BatchNorm3d(planes),
            nn.ReLU(inplace=True)
        )
        # 1x1x1; the trailing BatchNorm3d is the block's last norm layer and
        # is the one zeroed by VideoTrunkBuilder(zero_init_residual=True).
        self.conv3 = nn.Sequential(
            nn.Conv3d(planes, planes * self.expansion, kernel_size=1, bias=False),
            nn.BatchNorm3d(planes * self.expansion)
        )
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        """Return ``relu(conv3(conv2(conv1(x))) + residual)``."""
        residual = x

        out = self.conv1(x)
        out = self.conv2(out)
        out = self.conv3(out)

        if self.downsample is not None:
            residual = self.downsample(x)

        out += residual
        out = self.relu(out)

        return out


class VideoTrunkBuilder(nn.Module):
    """Stem, four residual stages, global average pool and a linear head."""

    def __init__(self, block, conv_makers, model_depth,
                 stem=None,
                 num_classes=400,
                 zero_init_residual=False):
        """Generic resnet video generator.

        Args:
            block (nn.Module): resnet building block
            conv_makers (list(functions)): generator function for each layer;
                each must also expose ``get_downsample_stride(stride)``
            model_depth (int): depth of the model; must be a key of
                ``BLOCK_CONFIG`` (traditional resnet depths).
            stem (nn.Sequential, optional): Resnet stem, if None, the result of
                ``get_default_stem()`` is used. Defaults to None.
            num_classes (int, optional): Dimension of the final FC layer.
                Defaults to 400.
            zero_init_residual (bool, optional): Zero init bottleneck residual
                BN. Defaults to False.

        Raises:
            KeyError: if ``model_depth`` is not a key of ``BLOCK_CONFIG``.
        """
        super(VideoTrunkBuilder, self).__init__()
        layers = BLOCK_CONFIG[model_depth]
        # Running channel count; updated by _make_layer as stages are built.
        self.inplanes = 64

        if stem is None:
            self.conv1 = get_default_stem()
        else:
            self.conv1 = stem

        self.layer1 = self._make_layer(block, conv_makers[0], 64, layers[0], stride=1)
        self.layer2 = self._make_layer(block, conv_makers[1], 128, layers[1], stride=2)
        self.layer3 = self._make_layer(block, conv_makers[2], 256, layers[2], stride=2)
        self.layer4 = self._make_layer(block, conv_makers[3], 512, layers[3], stride=2)

        self.avgpool = nn.AdaptiveAvgPool3d((1, 1, 1))
        self.fc = nn.Linear(512 * block.expansion, num_classes)

        # init weights
        self._initialize_weights()

        # Must run after _initialize_weights, which sets every BN weight to 1.
        if zero_init_residual:
            for m in self.modules():
                if isinstance(m, Bottleneck):
                    # Bottleneck has no ``bn3`` attribute: its last BatchNorm
                    # is the final element of the ``conv3`` Sequential.
                    nn.init.constant_(m.conv3[-1].weight, 0)

    def forward(self, x):
        """Run stem, the four stages, pooling and the FC head on ``x``."""
        x = self.conv1(x)

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)

        x = self.avgpool(x)
        # Flatten the layer to fc
        x = x.flatten(1)
        x = self.fc(x)

        return x

    def _make_layer(self, block, conv_builder, planes, blocks, stride=1):
        """Build one stage of ``blocks`` residual blocks as an nn.Sequential.

        Only the first block receives ``stride`` and the optional projection
        shortcut; the remaining blocks use the block's default stride and no
        downsample. Updates ``self.inplanes`` to ``planes * block.expansion``.
        """
        downsample = None

        # A projection shortcut is needed whenever the first block changes the
        # resolution or the channel count.
        if stride != 1 or self.inplanes != planes * block.expansion:
            ds_stride = conv_builder.get_downsample_stride(stride)
            downsample = nn.Sequential(
                nn.Conv3d(self.inplanes, planes * block.expansion,
                          kernel_size=1, stride=ds_stride, bias=False),
                nn.BatchNorm3d(planes * block.expansion)
            )

        layers = []
        layers.append(block(self.inplanes, planes, conv_builder, stride, downsample))

        self.inplanes = planes * block.expansion
        for _ in range(1, blocks):
            layers.append(block(self.inplanes, planes, conv_builder))

        return nn.Sequential(*layers)

    def _initialize_weights(self):
        """Initialize all Conv3d, BatchNorm3d and Linear submodules in place."""
        for m in self.modules():
            if isinstance(m, nn.Conv3d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out',
                                        nonlinearity='relu')
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.BatchNorm3d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                nn.init.normal_(m.weight, 0, 0.01)
                nn.init.constant_(m.bias, 0)