# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Contains definitions for the AssembleNet++ [2] models (without object input).

Requires the AssembleNet++ architecture to be specified in
FLAGS.model_structure (and optionally FLAGS.model_edge_weights). This is
identical to the form described in assemblenet.py for the AssembleNet. Please
check assemblenet.py for the detailed format of the model strings.

AssembleNet++ adds `peer-attention' to the basic AssembleNet, which allows each
conv. block connection to be conditioned differently based on another block
[2]. It is a form of channel-wise attention. Note that we learn to apply
attention independently for each frame.

The `peer-attention' implementation in this file is the version that enables
one-shot differentiable search of attention connectivity (Fig. 2 in [2]), using
a softmax weighted summation of possible attention vectors.

[2] Michael S. Ryoo, AJ Piergiovanni, Juhana Kangaspunta, Anelia Angelova,
    AssembleNet++: Assembling Modality Representations via Attention
    Connections. ECCV 2020
    https://arxiv.org/abs/2008.08072

In order to take advantage of object inputs, one will need to set the flag
FLAGS.use_object_input as True, and provide the list of input tensors as an
input to the network, as shown in run_asn_with_object.py. This will require a
pre-processed object data stream.

It uses (2+1)D convolutions for video representations. The main AssembleNet++
takes a 4-D (N*T)HWC tensor as an input (i.e., the batch dim and time dim are
mixed), and it reshapes a tensor to NT(H*W)C whenever a 1-D temporal conv. is
necessary. This is done so that the model runs efficiently on TPUs.
"""

import functools
from typing import Any, Dict, List, Mapping, Optional

from absl import logging
import numpy as np
import tensorflow as tf

from official.modeling import hyperparams
from official.projects.assemblenet.configs import assemblenet as cfg
from official.projects.assemblenet.modeling import assemblenet as asn
from official.projects.assemblenet.modeling import rep_flow_2d_layer as rf
from official.vision.beta.modeling import factory_3d as model_factory
from official.vision.beta.modeling.backbones import factory as backbone_factory

layers = tf.keras.layers


def softmax_merge_peer_attentions(peers):
  """Merge multiple peer-attention vectors with softmax weighted sum.

  Summation weights are to be learned.
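  For example (an illustrative sketch only, with made-up shapes): given two
  peer-attention vectors of shape `[batch*time, channels]`,

    merged = softmax_merge_peer_attentions(
        [tf.ones([8, 64]), tf.zeros([8, 64])])  # -> shape [8, 64]

  this computes softmax(w)[0] * peers[0] + softmax(w)[1] * peers[1], where w is
  the summation-weight vector of length len(peers).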
  Args:
    peers: A list of `Tensors` of size `[batch*time, channels]`.

  Returns:
    The output `Tensor` of size `[batch*time, channels]`.
  """
  data_format = tf.keras.backend.image_data_format()
  dtype = peers[0].dtype
  assert data_format == 'channels_last'

  initial_attn_weights = tf.keras.initializers.TruncatedNormal(stddev=0.01)(
      [len(peers)])

  attn_weights = tf.cast(tf.nn.softmax(initial_attn_weights), dtype)
  weighted_peers = []
  for i, peer in enumerate(peers):
    weighted_peers.append(attn_weights[i] * peer)

  return tf.add_n(weighted_peers)


def apply_attention(inputs,
                    attention_mode=None,
                    attention_in=None,
                    use_5d_mode=False):
  """Applies peer-attention or self-attention to the input tensor.

  Depending on the attention_mode, this function either applies channel-wise
  self-attention or peer-attention. For the peer-attention, the function
  combines multiple candidate attention vectors (given as attention_in), by
  learning softmax-sum weights described in the AssembleNet++ paper. Note that
  the attention is applied individually for each frame, which showed better
  accuracies than using video-level attention.

  Args:
    inputs: A `Tensor`. Either 4D or 5D, depending on use_5d_mode.
    attention_mode: `str` specifying mode. If not `peer', does self-attention.
    attention_in: A list of `Tensors' of size [batch*time, channels].
    use_5d_mode: `bool` indicating whether the inputs are in 5D tensor or 4D.

  Returns:
    The output `Tensor` after channel-wise attention is applied.
  """
  data_format = tf.keras.backend.image_data_format()
  assert data_format == 'channels_last'

  if use_5d_mode:
    h_channel_loc = 2
  else:
    h_channel_loc = 1

  if attention_mode == 'peer':
    attn = softmax_merge_peer_attentions(attention_in)
  else:
    attn = tf.math.reduce_mean(inputs, [h_channel_loc, h_channel_loc + 1])
  attn = tf.keras.layers.Dense(
      units=inputs.shape[-1],
      kernel_initializer=tf.random_normal_initializer(stddev=.01))(
          inputs=attn)
  attn = tf.math.sigmoid(attn)
  channel_attn = tf.expand_dims(
      tf.expand_dims(attn, h_channel_loc), h_channel_loc)

  inputs = tf.math.multiply(inputs, channel_attn)

  return inputs


class _ApplyEdgeWeight(layers.Layer):
  """Multiply weight on each input tensor.

  A weight is assigned for each connection (i.e., each input tensor). This
  layer is used by the fusion_with_peer_attention to compute the weighted
  inputs.
  """

  def __init__(self,
               weights_shape,
               index: Optional[int] = None,
               use_5d_mode: bool = False,
               model_edge_weights: Optional[List[Any]] = None,
               num_object_classes: Optional[int] = None,
               **kwargs):
    """Constructor.

    Args:
      weights_shape: A list of integers. Each element means the number of
        edges.
      index: `int` index of the block within the AssembleNet architecture. Used
        for summation weight initial loading.
      use_5d_mode: `bool` indicating whether the inputs are in 5D tensor or 4D.
      model_edge_weights: AssembleNet++ model structure connection weights in
        the string format.
      num_object_classes: Number of classes of the object input stream (e.g.,
        151 for ADE-20k). When object inputs are used, channels of this size
        are excluded from the largest-channel computation so that the object
        channel size does not inflate the number of parameters.
      **kwargs: pass through arguments.
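    For example (illustrative values only): a block with three incoming
    connections uses weights_shape=[3]; if model_edge_weights is provided,
    model_edge_weights[index] is expected to look like [[0.7, 0.1, 0.5]], whose
    first element holds the three post-sigmoid connection weights that build()
    converts back to logits via -log(1/w - 1) and keeps frozen.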
""" super(_ApplyEdgeWeight, self).__init__(**kwargs) self._weights_shape = weights_shape self._index = index self._use_5d_mode = use_5d_mode self._model_edge_weights = model_edge_weights self._num_object_classes = num_object_classes data_format = tf.keras.backend.image_data_format() assert data_format == 'channels_last' def get_config(self): config = { 'weights_shape': self._weights_shape, 'index': self._index, 'use_5d_mode': self._use_5d_mode, 'model_edge_weights': self._model_edge_weights, 'num_object_classes': self._num_object_classes } base_config = super(_ApplyEdgeWeight, self).get_config() return dict(list(base_config.items()) + list(config.items())) def build(self, input_shape: tf.TensorShape): if self._weights_shape[0] == 1: self._edge_weights = 1.0 return if self._index is None or not self._model_edge_weights: self._edge_weights = self.add_weight( shape=self._weights_shape, initializer=tf.keras.initializers.TruncatedNormal( mean=0.0, stddev=0.01), trainable=True, name='agg_weights') else: initial_weights_after_sigmoid = np.asarray( self._model_edge_weights[self._index][0]).astype('float32') # Initial_weights_after_sigmoid is never 0, as the initial weights are # based the results of a successful connectivity search. initial_weights = -np.log(1. / initial_weights_after_sigmoid - 1.) self._edge_weights = self.add_weight( shape=self._weights_shape, initializer=tf.constant_initializer(initial_weights), trainable=False, name='agg_weights') def call(self, inputs: List[tf.Tensor], training: Optional[bool] = None) -> Mapping[Any, List[tf.Tensor]]: use_5d_mode = self._use_5d_mode dtype = inputs[0].dtype assert len(inputs) > 1 if use_5d_mode: h_channel_loc = 2 else: h_channel_loc = 1 # get smallest spatial size and largest channels sm_size = [10000, 10000] lg_channel = 0 for inp in inputs: # assume batch X height x width x channels sm_size[0] = min(sm_size[0], inp.shape[h_channel_loc]) sm_size[1] = min(sm_size[1], inp.shape[h_channel_loc + 1]) # Note that, when using object inputs, object channel sizes are usually # big. Since we do not want the object channel size to increase the number # of parameters for every fusion, we exclude it when computing lg_channel. if inp.shape[-1] > lg_channel and inp.shape[-1] != self._num_object_classes: # pylint: disable=line-too-long lg_channel = inp.shape[3] # loads or creates weight variables to fuse multiple inputs weights = tf.math.sigmoid(tf.cast(self._edge_weights, dtype)) # Compute weighted inputs. We group inputs with the same channels. 
    per_channel_inps = dict({0: []})
    for i, inp in enumerate(inputs):
      if inp.shape[h_channel_loc] != sm_size[0] or inp.shape[h_channel_loc + 1] != sm_size[1]:  # pylint: disable=line-too-long
        assert sm_size[0] != 0
        ratio = (inp.shape[h_channel_loc] + 1) // sm_size[0]
        if use_5d_mode:
          inp = tf.keras.layers.MaxPool3D([1, ratio, ratio], [1, ratio, ratio],
                                          padding='same')(
                                              inp)
        else:
          inp = tf.keras.layers.MaxPool2D([ratio, ratio], ratio,
                                          padding='same')(
                                              inp)

      weights = tf.cast(weights, inp.dtype)
      if inp.shape[-1] in per_channel_inps:
        per_channel_inps[inp.shape[-1]].append(weights[i] * inp)
      else:
        per_channel_inps.update({inp.shape[-1]: [weights[i] * inp]})
    return per_channel_inps


def fusion_with_peer_attention(inputs: List[tf.Tensor],
                               index: Optional[int] = None,
                               attention_mode: Optional[str] = None,
                               attention_in: Optional[List[tf.Tensor]] = None,
                               use_5d_mode: bool = False,
                               model_edge_weights: Optional[List[Any]] = None,
                               num_object_classes: Optional[int] = None):
  """Weighted summation of multiple tensors, while using peer-attention.

  Summation weights are to be learned. Uses spatial max pooling and 1x1 conv.
  to match their sizes. Before the summation, each connection (i.e., each
  input) itself is scaled with channel-wise peer-attention. Notice that
  attention is applied for each connection, conditioned based on attention_in.

  Args:
    inputs: A list of `Tensors`. Either 4D or 5D, depending on use_5d_mode.
    index: `int` index of the block within the AssembleNet architecture. Used
      for summation weight initial loading.
    attention_mode: `str` specifying mode. If not `peer', does self-attention.
    attention_in: A list of `Tensors' of size [batch*time, channels].
    use_5d_mode: `bool` indicating whether the inputs are in 5D tensor or 4D.
    model_edge_weights: AssembleNet model structure connection weights in the
      string format.
    num_object_classes: Number of classes of the object input stream (e.g.,
      151 for ADE-20k). When object inputs are used, channels of this size are
      excluded from the largest-channel computation.

  Returns:
    The fused output `Tensor` after weighted summation.
  """
  if use_5d_mode:
    h_channel_loc = 2
    conv_function = asn.conv3d_same_padding
  else:
    h_channel_loc = 1
    conv_function = asn.conv2d_fixed_padding

  # If only 1 input.
  if len(inputs) == 1:
    inputs[0] = apply_attention(inputs[0], attention_mode, attention_in,
                                use_5d_mode)
    return inputs[0]

  # get smallest spatial size and largest channels
  sm_size = [10000, 10000]
  lg_channel = 0
  for inp in inputs:
    # assume batch X height x width x channels
    sm_size[0] = min(sm_size[0], inp.shape[h_channel_loc])
    sm_size[1] = min(sm_size[1], inp.shape[h_channel_loc + 1])
    # Note that, when using object inputs, object channel sizes are usually
    # big. Since we do not want the object channel size to increase the number
    # of parameters for every fusion, we exclude it when computing lg_channel.
    if inp.shape[-1] > lg_channel and inp.shape[-1] != num_object_classes:  # pylint: disable=line-too-long
      lg_channel = inp.shape[3]

  per_channel_inps = _ApplyEdgeWeight(
      weights_shape=[len(inputs)],
      index=index,
      use_5d_mode=use_5d_mode,
      model_edge_weights=model_edge_weights)(
          inputs)

  # Implementation of connectivity with peer-attention
  if attention_mode:
    for key, channel_inps in per_channel_inps.items():
      for idx in range(len(channel_inps)):
        with tf.name_scope('Connection_' + str(key) + '_' + str(idx)):
          channel_inps[idx] = apply_attention(channel_inps[idx],
                                              attention_mode, attention_in,
                                              use_5d_mode)

  # Adding 1x1 conv layers (to match channel size) and fusing all inputs.
  # We add inputs with the same channels first before applying 1x1 conv to save
  # memory.
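  # Illustrative sketch of the fusion below: each channel-size group with
  # members x_1..x_k is reduced to
  #   x_1 + ... + x_k            if its channel size equals lg_channel,
  #   conv1x1(x_1 + ... + x_k)   otherwise (projecting to lg_channel),
  # and the per-group results are then summed with tf.add_n.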
  inps = []
  for key, channel_inps in per_channel_inps.items():
    if len(channel_inps) < 1:
      continue
    if len(channel_inps) == 1:
      if key == lg_channel:
        inp = channel_inps[0]
      else:
        inp = conv_function(
            channel_inps[0], lg_channel, kernel_size=1, strides=1)
      inps.append(inp)
    else:
      if key == lg_channel:
        inp = tf.add_n(channel_inps)
      else:
        inp = conv_function(
            tf.add_n(channel_inps), lg_channel, kernel_size=1, strides=1)
      inps.append(inp)

  return tf.add_n(inps)


def object_conv_stem(inputs):
  """Layers for an object input stem.

  It expects its input tensor to have a separate channel for each object class.
  Each channel should correspond to one object class.

  Args:
    inputs: A `Tensor`.

  Returns:
    The output `Tensor`.
  """
  inputs = tf.keras.layers.MaxPool2D(
      pool_size=4, strides=4, padding='SAME')(
          inputs=inputs)
  inputs = tf.identity(inputs, 'initial_max_pool')

  return inputs


class AssembleNetPlus(tf.keras.Model):
  """AssembleNet++ backbone."""

  def __init__(self,
               block_fn,
               num_blocks: List[int],
               num_frames: int,
               model_structure: List[Any],
               input_specs: layers.InputSpec = layers.InputSpec(
                   shape=[None, None, None, None, 3]),
               model_edge_weights: Optional[List[Any]] = None,
               use_object_input: bool = False,
               attention_mode: str = 'peer',
               bn_decay: float = rf.BATCH_NORM_DECAY,
               bn_epsilon: float = rf.BATCH_NORM_EPSILON,
               use_sync_bn: bool = False,
               **kwargs):
    """Generator for AssembleNet++ models.

    Args:
      block_fn: `function` for the block to use within the model. Currently
        only has `bottleneck_block_interleave` as its option.
      num_blocks: list of 4 `int`s denoting the number of blocks to include in
        each of the 4 block groups. Each group consists of blocks that take
        inputs of the same resolution.
      num_frames: the number of frames in the input tensor.
      model_structure: AssembleNetPlus model structure in the string format.
      input_specs: `tf.keras.layers.InputSpec` specs of the input tensor.
        Dimension should be `[batch*time, height, width, channels]`.
      model_edge_weights: AssembleNet model structure connection weight in the
        string format.
      use_object_input: `bool` indicating whether object inputs are used.
      attention_mode: `str` attention mode; use 'peer' for peer-attention,
        otherwise channel-wise self-attention is applied.
      bn_decay: `float` batch norm decay parameter to use.
      bn_epsilon: `float` batch norm epsilon parameter to use.
      use_sync_bn: use synchronized batch norm for TPU.
      **kwargs: pass through arguments.

    Returns:
      Model `function` that takes in `inputs` and `is_training` and returns the
      output `Tensor` of the AssembleNetPlus model.
    """
    data_format = tf.keras.backend.image_data_format()

    # Creation of the model graph.
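    # model_structure is a list of per-node specs (see assemblenet.py for the
    # authoritative format). As read below: stem nodes have a negative level in
    # structure[i][0] (-1: RGB stem, -2: representation-flow stem, -3: object
    # stem) with structure[i][1] as the temporal dilation, while regular nodes
    # are read as [level, [input_node_indices], filters, temporal_dilation,
    # strides]. This comment only describes how the constructor indexes the
    # structure; it is not a recommended architecture.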
logging.info('model_structure=%r', model_structure) logging.info('model_structure=%r', model_structure) logging.info('model_edge_weights=%r', model_edge_weights) structure = model_structure if use_object_input: original_inputs = tf.keras.Input(shape=input_specs[0].shape[1:]) object_inputs = tf.keras.Input(shape=input_specs[1].shape[1:]) input_specs = input_specs[0] else: original_inputs = tf.keras.Input(shape=input_specs.shape[1:]) object_inputs = None original_num_frames = num_frames assert num_frames > 0, f'Invalid num_frames {num_frames}' grouping = {-3: [], -2: [], -1: [], 0: [], 1: [], 2: [], 3: []} for i in range(len(structure)): grouping[structure[i][0]].append(i) stem_count = len(grouping[-3]) + len(grouping[-2]) + len(grouping[-1]) assert stem_count != 0 stem_filters = 128 // stem_count if len(input_specs.shape) == 5: first_dim = ( input_specs.shape[0] * input_specs.shape[1] if input_specs.shape[0] and input_specs.shape[1] else -1) reshape_inputs = tf.reshape(original_inputs, (first_dim,) + input_specs.shape[2:]) elif len(input_specs.shape) == 4: reshape_inputs = original_inputs else: raise ValueError( f'Expect input spec to be 4 or 5 dimensions {input_specs.shape}') if grouping[-2]: # Instead of loading optical flows as inputs from data pipeline, we are # applying the "Representation Flow" to RGB frames so that we can compute # the flow within TPU/GPU on fly. It's essentially optical flow since we # do it with RGBs. axis = 3 if data_format == 'channels_last' else 1 flow_inputs = rf.RepresentationFlow( original_num_frames, depth=reshape_inputs.shape.as_list()[axis], num_iter=40, bottleneck=1)( reshape_inputs) streams = [] for i in range(len(structure)): with tf.name_scope('Node_' + str(i)): if structure[i][0] == -1: inputs = asn.rgb_conv_stem( reshape_inputs, original_num_frames, stem_filters, temporal_dilation=structure[i][1], bn_decay=bn_decay, bn_epsilon=bn_epsilon, use_sync_bn=use_sync_bn) streams.append(inputs) elif structure[i][0] == -2: inputs = asn.flow_conv_stem( flow_inputs, stem_filters, temporal_dilation=structure[i][1], bn_decay=bn_decay, bn_epsilon=bn_epsilon, use_sync_bn=use_sync_bn) streams.append(inputs) elif structure[i][0] == -3: # In order to use the object inputs, you need to feed your object # input tensor here. 
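          # The object stream is assumed to be a [batch*time, height, width,
          # num_object_classes] tensor of per-class object scores (e.g., 151
          # channels for ADE-20k), coming from the pre-processed object data
          # stream described in run_asn_with_object.py.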
inputs = object_conv_stem(object_inputs) streams.append(inputs) else: block_number = structure[i][0] combined_inputs = [ streams[structure[i][1][j]] for j in range(0, len(structure[i][1])) ] logging.info(grouping) nodes_below = [] for k in range(-3, structure[i][0]): nodes_below = nodes_below + grouping[k] peers = [] if attention_mode: lg_channel = -1 # To show structures for attention we show nodes_below logging.info(nodes_below) for k in nodes_below: logging.info(streams[k].shape) lg_channel = max(streams[k].shape[3], lg_channel) for node_index in nodes_below: attn = tf.reduce_mean(streams[node_index], [1, 2]) attn = tf.keras.layers.Dense( units=lg_channel, kernel_initializer=tf.random_normal_initializer(stddev=.01))( inputs=attn) peers.append(attn) combined_inputs = fusion_with_peer_attention( combined_inputs, index=i, attention_mode=attention_mode, attention_in=peers, use_5d_mode=False) graph = asn.block_group( inputs=combined_inputs, filters=structure[i][2], block_fn=block_fn, blocks=num_blocks[block_number], strides=structure[i][4], name='block_group' + str(i), block_level=structure[i][0], num_frames=num_frames, temporal_dilation=structure[i][3]) streams.append(graph) if use_object_input: inputs = [original_inputs, object_inputs] else: inputs = original_inputs super(AssembleNetPlus, self).__init__( inputs=inputs, outputs=streams, **kwargs) @tf.keras.utils.register_keras_serializable(package='Vision') class AssembleNetPlusModel(tf.keras.Model): """An AssembleNet++ model builder.""" def __init__(self, backbone, num_classes, num_frames: int, model_structure: List[Any], input_specs: Optional[Dict[str, tf.keras.layers.InputSpec]] = None, max_pool_predictions: bool = False, use_object_input: bool = False, **kwargs): if not input_specs: input_specs = { 'image': layers.InputSpec(shape=[None, None, None, None, 3]) } if use_object_input and 'object' not in input_specs: input_specs['object'] = layers.InputSpec(shape=[None, None, None, None]) self._self_setattr_tracking = False self._config_dict = { 'backbone': backbone, 'num_classes': num_classes, 'num_frames': num_frames, 'input_specs': input_specs, 'model_structure': model_structure, } self._input_specs = input_specs self._backbone = backbone grouping = {-3: [], -2: [], -1: [], 0: [], 1: [], 2: [], 3: []} for i in range(len(model_structure)): grouping[model_structure[i][0]].append(i) inputs = { k: tf.keras.Input(shape=v.shape[1:]) for k, v in input_specs.items() } if use_object_input: streams = self._backbone(inputs=[inputs['image'], inputs['object']]) else: streams = self._backbone(inputs=inputs['image']) outputs = asn.multi_stream_heads( streams, grouping[3], num_frames, num_classes, max_pool_predictions=max_pool_predictions) super(AssembleNetPlusModel, self).__init__( inputs=inputs, outputs=outputs, **kwargs) @property def checkpoint_items(self): """Returns a dictionary of items to be additionally checkpointed.""" return dict(backbone=self.backbone) @property def backbone(self): return self._backbone def get_config(self): return self._config_dict @classmethod def from_config(cls, config, custom_objects=None): return cls(**config) def assemblenet_plus(assemblenet_depth: int, num_classes: int, num_frames: int, model_structure: List[Any], input_specs: layers.InputSpec = layers.InputSpec( shape=[None, None, None, None, 3]), model_edge_weights: Optional[List[Any]] = None, use_object_input: bool = False, attention_mode: Optional[str] = None, max_pool_predictions: bool = False, **kwargs): """Returns the AssembleNet++ model for a given size and 
number of output classes.""" data_format = tf.keras.backend.image_data_format() assert data_format == 'channels_last' if assemblenet_depth not in asn.ASSEMBLENET_SPECS: raise ValueError('Not a valid assemblenet_depth:', assemblenet_depth) if use_object_input: # assuming input_specs = [vide, obj] when use_object_input = True input_specs_dict = {'image': input_specs[0], 'object': input_specs[1]} else: input_specs_dict = {'image': input_specs} params = asn.ASSEMBLENET_SPECS[assemblenet_depth] backbone = AssembleNetPlus( block_fn=params['block'], num_blocks=params['num_blocks'], num_frames=num_frames, model_structure=model_structure, input_specs=input_specs, model_edge_weights=model_edge_weights, use_object_input=use_object_input, attention_mode=attention_mode, **kwargs) return AssembleNetPlusModel( backbone, num_classes=num_classes, num_frames=num_frames, model_structure=model_structure, input_specs=input_specs_dict, use_object_input=use_object_input, max_pool_predictions=max_pool_predictions, **kwargs) @backbone_factory.register_backbone_builder('assemblenet_plus') def build_assemblenet_plus( input_specs: tf.keras.layers.InputSpec, backbone_config: hyperparams.Config, norm_activation_config: hyperparams.Config, l2_regularizer: Optional[tf.keras.regularizers.Regularizer] = None ) -> tf.keras.Model: """Builds assemblenet++ backbone.""" del l2_regularizer backbone_type = backbone_config.type backbone_cfg = backbone_config.get() assert backbone_type == 'assemblenet_plus' assemblenet_depth = int(backbone_cfg.model_id) if assemblenet_depth not in asn.ASSEMBLENET_SPECS: raise ValueError('Not a valid assemblenet_depth:', assemblenet_depth) model_structure, model_edge_weights = cfg.blocks_to_flat_lists( backbone_cfg.blocks) params = asn.ASSEMBLENET_SPECS[assemblenet_depth] block_fn = functools.partial( params['block'], use_sync_bn=norm_activation_config.use_sync_bn, bn_decay=norm_activation_config.norm_momentum, bn_epsilon=norm_activation_config.norm_epsilon) backbone = AssembleNetPlus( block_fn=block_fn, num_blocks=params['num_blocks'], num_frames=backbone_cfg.num_frames, model_structure=model_structure, input_specs=input_specs, model_edge_weights=model_edge_weights, use_object_input=backbone_cfg.use_object_input, attention_mode=backbone_cfg.attention_mode, use_sync_bn=norm_activation_config.use_sync_bn, bn_decay=norm_activation_config.norm_momentum, bn_epsilon=norm_activation_config.norm_epsilon) logging.info('Number of parameters in AssembleNet++ backbone: %f M.', backbone.count_params() / 10.**6) return backbone @model_factory.register_model_builder('assemblenet_plus') def build_assemblenet_plus_model( input_specs: tf.keras.layers.InputSpec, model_config: cfg.AssembleNetPlusModel, num_classes: int, l2_regularizer: Optional[tf.keras.regularizers.Regularizer] = None): """Builds assemblenet++ model.""" input_specs_dict = {'image': input_specs} backbone = build_assemblenet_plus(input_specs, model_config.backbone, model_config.norm_activation, l2_regularizer) backbone_cfg = model_config.backbone.get() model_structure, _ = cfg.blocks_to_flat_lists(backbone_cfg.blocks) model = AssembleNetPlusModel( backbone, num_classes=num_classes, num_frames=backbone_cfg.num_frames, model_structure=model_structure, input_specs=input_specs_dict, max_pool_predictions=model_config.max_pool_predictions, use_object_input=backbone_cfg.use_object_input) return model
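# Example usage (an illustrative sketch only, not part of this module's API;
# the `model_structure` value below is hypothetical -- see assemblenet.py for
# the real format and official/projects/assemblenet/configs/assemblenet.py for
# full structure definitions):
#
#   import tensorflow as tf
#
#   structure = [[-1, 4], [-2, 4], [0, [0, 1], 32, 1, 1], [3, [2], 256, 1, 2]]
#   model = assemblenet_plus(
#       assemblenet_depth=50,
#       num_classes=157,
#       num_frames=8,
#       model_structure=structure,
#       attention_mode='peer')
#   logits = model({'image': tf.ones([1, 8, 224, 224, 3])})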