Unverified Commit 2a370fe0 authored by Jekaterina Jaroslavceva, committed by GitHub

Ranking losses, normalization and pooling layers. (#9727)

* Ranking losses, normalization and pooling layers.

* Ranking losses, normalization and pooling layers.
parent c508968c
# Copyright 2021 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
# Copyright 2021 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Normalization layer definitions."""
import tensorflow as tf
class L2Normalization(tf.keras.layers.Layer):
"""Normalization layer using L2 norm."""
def __init__(self):
"""Initialization of the L2Normalization layer."""
super(L2Normalization, self).__init__()
# A lower bound value for the norm.
self.eps = 1e-6
def call(self, x, axis=1):
"""Invokes the L2Normalization instance.
Args:
x: A Tensor.
axis: Dimension along which to normalize. A scalar or a vector of
integers.
Returns:
norm: A Tensor with the same shape as `x`.
"""
return tf.nn.l2_normalize(x, axis, epsilon=self.eps)
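A minimal usage sketch of the L2Normalization layer above; the input values are illustrative and not part of the commit. Each row of a batch of embeddings is scaled to unit L2 norm along `axis=1`.

import tensorflow as tf

from delf.python.normalization_layers import normalization

# Two 3-D embeddings; axis=1 normalizes each row independently.
embeddings = tf.constant([[3.0, 4.0, 0.0],
                          [0.0, 0.0, 2.0]])
layer = normalization.L2Normalization()
normalized = layer(embeddings, axis=1)
# Each row now has unit norm: [[0.6, 0.8, 0.0], [0.0, 0.0, 1.0]].
print(tf.norm(normalized, axis=1))  # ~[1.0, 1.0]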
# Copyright 2021 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for normalization layers."""
import tensorflow as tf
from delf.python.normalization_layers import normalization
class NormalizationsTest(tf.test.TestCase):
def testL2Normalization(self):
x = tf.constant([-4.0, 0.0, 4.0])
layer = normalization.L2Normalization()
# Run tested function.
result = layer(x, axis=0)
# Define expected result.
exp_output = [-0.70710677, 0.0, 0.70710677]
# Compare actual and expected.
self.assertAllClose(exp_output, result)
if __name__ == '__main__':
tf.test.main()
# Copyright 2021 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
# Copyright 2021 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Pooling layers definitions."""
import tensorflow as tf
class MAC(tf.keras.layers.Layer):
"""Global max pooling (MAC) layer.
Maximum Activations of Convolutions (MAC) is simply constructed by
max-pooling over all dimensions per feature map. See
https://arxiv.org/abs/1511.05879 for a reference.
"""
def __init__(self):
"""Initialization of the global max pooling (MAC) layer."""
super(MAC, self).__init__()
def call(self, x, axis=[1, 2]):
"""Invokes the MAC pooling instance.
Args:
x: [B, H, W, D] A float32 Tensor.
axis: Dimensions to reduce.
Returns:
output: [B, D] A float32 Tensor.
"""
return mac(x, axis=axis)
class SPoC(tf.keras.layers.Layer):
"""Average pooling (SPoC) layer.
Sum-pooled convolutional features (SPoC) is based on the sum pooling of the
deep features. See https://arxiv.org/pdf/1510.07493.pdf for a reference."""
def __init__(self):
"""Initialization of the SPoC layer."""
super(SPoC, self).__init__()
def call(self, x, axis=[1, 2]):
"""Invokes the SPoC instance.
Args:
x: [B, H, W, D] A float32 Tensor.
axis: Dimensions to reduce.
Returns:
output: [B, D] A float32 Tensor.
"""
return spoc(x, axis)
class GeM(tf.keras.layers.Layer):
"""Generalized mean pooling (GeM) layer.
Generalized Mean Pooling (GeM) computes the generalized mean of each
channel in a tensor. See https://arxiv.org/abs/1711.02512 for a reference.
"""
def __init__(self, power=3.):
"""Initialization of the generalized mean pooling (GeM) layer.
Args:
power: Float power > 0 is an inverse exponent parameter, used during
the generalized mean pooling computation. Setting this exponent as power
> 1 increases the contrast of the pooled feature map and focuses on
the salient features of the image. GeM is a generalization of the
average pooling commonly used in classification networks (power = 1) and
of spatial max-pooling layer (power = inf).
"""
super(GeM, self).__init__()
self.power = power
self.eps = 1e-6
def call(self, x, axis=[1, 2]):
"""Invokes the GeM instance.
Args:
x: [B, H, W, D] A float32 Tensor.
axis: Dimensions to reduce.
Returns:
output: [B, D] A float32 Tensor.
"""
return gem(x, power=self.power, eps=self.eps, axis=axis)
def mac(x, axis=[1, 2]):
"""Performs global max pooling (MAC).
Args:
x: [B, H, W, D] A float32 Tensor.
axis: Dimensions to reduce.
Returns:
output: [B, D] A float32 Tensor.
"""
return tf.reduce_max(x, axis=axis, keepdims=False)
def spoc(x, axis=[1, 2]):
"""Performs average pooling (SPoC).
Args:
x: [B, H, W, D] A float32 Tensor.
axis: Dimensions to reduce.
Returns:
output: [B, D] A float32 Tensor.
"""
return tf.reduce_mean(x, axis=axis, keepdims=False)
def gem(x, axis=[1, 2], power=3., eps=1e-6):
"""Performs generalized mean pooling (GeM).
Args:
x: [B, H, W, D] A float32 Tensor.
axis: Dimensions to reduce.
power: Float, power > 0 is an inverse exponent parameter (GeM power).
eps: Float, parameter for numerical stability.
Returns:
output: [B, D] A float32 Tensor.
"""
tmp = tf.pow(tf.maximum(x, eps), power)
out = tf.pow(tf.reduce_mean(tmp, axis=axis, keepdims=False), 1. / power)
return out
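A small worked example of the three pooling functions above, using the same module import as the tests in this commit; the feature-map values are illustrative. For per-channel activations x_1..x_N, GeM computes (1/N * sum_i max(x_i, eps)^p)^(1/p), which reduces to SPoC at p = 1 and approaches MAC as p -> inf.

import tensorflow as tf

from delf.python.pooling_layers import pooling

# One feature map of shape [B=1, H=2, W=2, D=1] with channel values 1..4.
x = tf.constant([[[[1.], [2.]],
                  [[3.], [4.]]]])

print(pooling.spoc(x))           # [[2.5]] - average pooling (GeM, power=1).
print(pooling.mac(x))            # [[4.]]  - max pooling (GeM, power->inf).
# GeM with power=3: (mean([1, 8, 27, 64]))**(1/3) = 25**(1/3) ~= 2.924,
# which lands between the average and the max.
print(pooling.gem(x, power=3.))  # [[2.9240177]]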
-# Lint as: python3
-# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+# Copyright 2021 The TensorFlow Authors All Rights Reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -13,41 +12,44 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # ==============================================================================
-"""Tests for the ResNet backbone."""
+"""Tests for pooling layers."""
 
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
-import numpy as np
 import tensorflow as tf
 
-from delf.python.training.model import resnet50
+from delf.python.pooling_layers import pooling
 
 
-class Resnet50Test(tf.test.TestCase):
+class PoolingsTest(tf.test.TestCase):
 
-  def test_gem_pooling_works(self):
-    # Input feature map: Batch size = 2, height = 1, width = 2, depth = 2.
-    feature_map = tf.constant([[[[.0, 2.0], [1.0, -1.0]]],
-                               [[[1.0, 100.0], [1.0, .0]]]],
-                              dtype=tf.float32)
-    power = 2.0
-    threshold = .0
-
-    # Run tested function.
-    pooled_feature_map = resnet50.gem_pooling(feature_map=feature_map,
-                                              axis=[1, 2],
-                                              power=power,
-                                              threshold=threshold)
-    # Define expected result.
-    expected_pooled_feature_map = np.array([[0.707107, 1.414214],
-                                            [1.0, 70.710678]],
-                                           dtype=float)
-    # Compare actual and expected.
-    self.assertAllClose(pooled_feature_map, expected_pooled_feature_map)
+  def testMac(self):
+    x = tf.constant([[[[0., 1.], [2., 3.]],
+                      [[4., 5.], [6., 7.]]]])
+    # Run tested function.
+    result = pooling.mac(x)
+    # Define expected result.
+    exp_output = [[6., 7.]]
+    # Compare actual and expected.
+    self.assertAllClose(exp_output, result)
+
+  def testSpoc(self):
+    x = tf.constant([[[[0., 1.], [2., 3.]],
+                      [[4., 5.], [6., 7.]]]])
+    # Run tested function.
+    result = pooling.spoc(x)
+    # Define expected result.
+    exp_output = [[3., 4.]]
+    # Compare actual and expected.
+    self.assertAllClose(exp_output, result)
+
+  def testGem(self):
+    x = tf.constant([[[[0., 1.], [2., 3.]],
+                      [[4., 5.], [6., 7.]]]])
+    # Run tested function.
+    result = pooling.gem(x, power=3., eps=1e-6)
+    # Define expected result.
+    exp_output = [[4.1601677, 4.9866314]]
+    # Compare actual and expected.
+    self.assertAllClose(exp_output, result)
 
 if __name__ == '__main__':
# Copyright 2021 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
# Copyright 2021 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Ranking loss definitions."""
import tensorflow as tf
class ContrastiveLoss(tf.keras.losses.Loss):
"""Contrastive Loss layer.
Contrastive Loss layer allows to compute contrastive loss for a batch of
images. Implementation based on: https://arxiv.org/abs/1604.02426.
"""
def __init__(self, margin=0.7, reduction=tf.keras.losses.Reduction.NONE):
"""Initialization of Contrastive Loss layer.
Args:
margin: Float contrastive loss margin.
reduction: Type of loss reduction.
"""
super(ContrastiveLoss, self).__init__(reduction)
self.margin = margin
# Parameter for numerical stability.
self.eps = 1e-6
def __call__(self, queries, positives, negatives):
"""Invokes the Contrastive Loss instance.
Args:
queries: [B, D] Anchor input tensor.
positives: [B, D] Positive sample input tensor.
negatives: [B, Nneg, D] Negative sample input tensor.
Returns:
loss: Scalar tensor.
"""
return contrastive_loss(queries, positives, negatives,
margin=self.margin, eps=self.eps)
class TripletLoss(tf.keras.losses.Loss):
"""Triplet Loss layer.
Triplet Loss layer computes triplet loss for a batch of images. Triplet
loss tries to keep all queries closer to positives than to any negatives.
Margin is used to specify when a triplet has become too "easy" and we no
longer want to adjust the weights from it. Differently from the Contrastive
Loss, Triplet Loss uses squared distances when computing the loss.
Implementation based on: https://arxiv.org/abs/1511.07247.
"""
def __init__(self, margin=0.1, reduction=tf.keras.losses.Reduction.NONE):
"""Initialization of Triplet Loss layer.
Args:
margin: Triplet loss margin.
reduction: Type of loss reduction.
"""
super(TripletLoss, self).__init__(reduction)
self.margin = margin
def __call__(self, queries, positives, negatives):
"""Invokes the Triplet Loss instance.
Args:
queries: [B, D] Anchor input tensor.
positives: [B, D] Positive sample input tensor.
negatives: [B, Nneg, D] Negative sample input tensor.
Returns:
loss: Scalar tensor.
"""
return triplet_loss(queries, positives, negatives, margin=self.margin)
def contrastive_loss(queries, positives, negatives, margin=0.7,
eps=1e-6):
"""Calculates Contrastive Loss.
We expect the `queries`, `positives` and `negatives` to be normalized with
unit length for training stability. The contrastive loss directly
optimizes this distance by encouraging all positive distances to
approach 0, while keeping negative distances above a certain threshold.
Args:
queries: [B, D] Anchor input tensor.
positives: [B, D] Positive sample input tensor.
negatives: [B, Nneg, D] Negative sample input tensor.
margin: Float contrastive loss loss margin.
eps: Float parameter for numerical stability.
Returns:
loss: Scalar tensor.
"""
D = tf.shape(queries)[1]
# Number of `queries`.
B = tf.shape(queries)[0]
# Number of `positives`.
np = tf.shape(positives)[0]
# Number of `negatives`.
Nneg = tf.shape(negatives)[1]
# Preparing negatives.
stacked_negatives = tf.reshape(negatives, [Nneg * B, D])
# Preparing queries for further loss calculation.
stacked_queries = tf.repeat(queries, Nneg + 1, axis=0)
positives_and_negatives = tf.concat([positives, stacked_negatives], axis=0)
# Calculate an Euclidean norm for each pair of points. For any positive
# pair of data points this distance should be small, and for
# negative pair it should be large.
distances = tf.norm(stacked_queries - positives_and_negatives + eps, axis=1)
positives_part = 0.5 * tf.pow(distances[:np], 2.0)
negatives_part = 0.5 * tf.pow(tf.math.maximum(margin - distances[np:], 0),
2.0)
# Final contrastive loss calculation.
loss = tf.reduce_sum(tf.concat([positives_part, negatives_part], 0))
return loss
def triplet_loss(queries, positives, negatives, margin=0.1):
"""Calculates Triplet Loss.
Triplet loss tries to keep all queries closer to positives than to any
negatives. Differently from the Contrastive Loss, Triplet Loss uses squared
distances when computing the loss.
Args:
queries: [B, D] Anchor input tensor.
positives: [B, D] Positive sample input tensor.
negatives: [B, Nneg, D] Negative sample input tensor.
margin: Float triplet loss loss margin.
Returns:
loss: Scalar tensor.
"""
D = tf.shape(queries)[1]
# Number of `queries`.
B = tf.shape(queries)[0]
# Number of `negatives`.
Nneg = tf.shape(negatives)[1]
# Preparing negatives.
stacked_negatives = tf.reshape(negatives, [Nneg * B, D])
# Preparing queries for further loss calculation.
stacked_queries = tf.repeat(queries, Nneg, axis=0)
# Preparing positives for further loss calculation.
stacked_positives = tf.repeat(positives, Nneg, axis=0)
# Computes *squared* distances.
distance_positives = tf.reduce_sum(
tf.square(stacked_queries - stacked_positives), axis=1)
distance_negatives = tf.reduce_sum(tf.square(stacked_queries -
stacked_negatives), axis=1)
# Final triplet loss calculation.
loss = tf.reduce_sum(tf.maximum(distance_positives -
distance_negatives + margin, 0.0))
return loss
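A hedged usage sketch of the two losses above; the embeddings are illustrative, with shapes matching the docstrings (B = 1 query, Nneg = 2 negatives, D = 4 dimensions). Inputs are L2-normalized first, as the contrastive-loss docstring recommends.

import tensorflow as tf

from delf.python.training.losses import ranking_losses

queries = tf.math.l2_normalize(tf.constant([[1., 0., 1., 0.]]), axis=1)
positives = tf.math.l2_normalize(tf.constant([[1., 0., 0., 1.]]), axis=1)
negatives = tf.math.l2_normalize(
    tf.constant([[[0., 1., 0., 1.], [0., 1., 1., 0.]]]), axis=2)

# Contrastive: 0.5 * d(q, p)^2 for the positive pair, plus
# 0.5 * max(margin - d(q, n), 0)^2 summed over the negatives.
c_loss = ranking_losses.contrastive_loss(
    queries, positives, negatives, margin=0.7)
# Triplet: max(d(q, p)^2 - d(q, n)^2 + margin, 0) summed over the negatives,
# with *squared* distances.
t_loss = ranking_losses.triplet_loss(
    queries, positives, negatives, margin=0.1)
print(c_loss.numpy(), t_loss.numpy())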
# Copyright 2021 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for Ranking losses."""
import tensorflow as tf
from delf.python.training.losses import ranking_losses
class RankingLossesTest(tf.test.TestCase):
def testContrastiveLoss(self):
# Testing the correct numeric value.
queries = tf.math.l2_normalize(tf.constant([[1.0, 2.0, -2.0]]))
positives = tf.math.l2_normalize(tf.constant([[-1.0, 2.0, 0.0]]))
negatives = tf.math.l2_normalize(tf.constant([[[-5.0, 0.0, 3.0]]]))
result = ranking_losses.contrastive_loss(queries, positives, negatives,
margin=0.7, eps=1e-6)
exp_output = 0.55278635
self.assertAllClose(exp_output, result)
def testTripletLossZeroLoss(self):
# Testing the correct numeric value in case if query-positive distance is
# smaller than the query-negative distance.
queries = tf.math.l2_normalize(tf.constant([[1.0, 2.0, -2.0]]))
positives = tf.math.l2_normalize(tf.constant([[-1.0, 2.0, 0.0]]))
negatives = tf.math.l2_normalize(tf.constant([[[-5.0, 0.0, 3.0]]]))
result = ranking_losses.triplet_loss(queries, positives, negatives,
margin=0.1)
exp_output = 0.0
self.assertAllClose(exp_output, result)
def testTripletLossNonZeroLoss(self):
# Testing the correct numeric value in case if query-positive distance is
# bigger than the query-negative distance.
queries = tf.math.l2_normalize(tf.constant([[1.0, 2.0, -2.0]]))
positives = tf.math.l2_normalize(tf.constant([[-5.0, 0.0, 3.0]]))
negatives = tf.math.l2_normalize(tf.constant([[[-1.0, 2.0, 0.0]]]))
result = ranking_losses.triplet_loss(queries, positives, negatives,
margin=0.1)
exp_output = 2.2520838
self.assertAllClose(exp_output, result)
if __name__ == '__main__':
tf.test.main()
@@ -29,6 +29,7 @@ from absl import logging
 import h5py
 import tensorflow as tf
 
+from delf.python.pooling_layers import pooling as pooling_layers
 
 layers = tf.keras.layers
@@ -54,23 +55,23 @@ class _IdentityBlock(tf.keras.Model):
     bn_axis = 1 if data_format == 'channels_first' else 3
 
     self.conv2a = layers.Conv2D(
         filters1, (1, 1), name=conv_name_base + '2a', data_format=data_format)
     self.bn2a = layers.BatchNormalization(
         axis=bn_axis, name=bn_name_base + '2a')
 
     self.conv2b = layers.Conv2D(
         filters2,
         kernel_size,
         padding='same',
         data_format=data_format,
         name=conv_name_base + '2b')
     self.bn2b = layers.BatchNormalization(
         axis=bn_axis, name=bn_name_base + '2b')
 
     self.conv2c = layers.Conv2D(
         filters3, (1, 1), name=conv_name_base + '2c', data_format=data_format)
     self.bn2c = layers.BatchNormalization(
         axis=bn_axis, name=bn_name_base + '2c')
 
   def call(self, input_tensor, training=False):
     x = self.conv2a(input_tensor)
@@ -118,34 +119,34 @@ class _ConvBlock(tf.keras.Model):
     bn_axis = 1 if data_format == 'channels_first' else 3
 
     self.conv2a = layers.Conv2D(
         filters1, (1, 1),
         strides=strides,
         name=conv_name_base + '2a',
         data_format=data_format)
     self.bn2a = layers.BatchNormalization(
         axis=bn_axis, name=bn_name_base + '2a')
 
     self.conv2b = layers.Conv2D(
         filters2,
         kernel_size,
         padding='same',
         name=conv_name_base + '2b',
         data_format=data_format)
     self.bn2b = layers.BatchNormalization(
         axis=bn_axis, name=bn_name_base + '2b')
 
     self.conv2c = layers.Conv2D(
         filters3, (1, 1), name=conv_name_base + '2c', data_format=data_format)
     self.bn2c = layers.BatchNormalization(
         axis=bn_axis, name=bn_name_base + '2c')
 
     self.conv_shortcut = layers.Conv2D(
         filters3, (1, 1),
         strides=strides,
         name=conv_name_base + '1',
         data_format=data_format)
     self.bn_shortcut = layers.BatchNormalization(
         axis=bn_axis, name=bn_name_base + '1')
 
   def call(self, input_tensor, training=False):
     x = self.conv2a(input_tensor)
@@ -222,23 +223,23 @@ class ResNet50(tf.keras.Model):
     def conv_block(filters, stage, block, strides=(2, 2)):
       return _ConvBlock(
           3,
           filters,
           stage=stage,
           block=block,
           data_format=data_format,
           strides=strides)
 
     def id_block(filters, stage, block):
       return _IdentityBlock(
           3, filters, stage=stage, block=block, data_format=data_format)
 
     self.conv1 = layers.Conv2D(
         64, (7, 7),
         strides=(2, 2),
         data_format=data_format,
         padding='same',
         name='conv1')
     bn_axis = 1 if data_format == 'channels_first' else 3
     self.bn_conv1 = layers.BatchNormalization(axis=bn_axis, name='bn_conv1')
     self.max_pool = layers.MaxPooling2D((3, 3),
@@ -288,14 +289,14 @@ class ResNet50(tf.keras.Model):
     reduction_indices = tf.constant(reduction_indices)
     if pooling == 'avg':
       self.global_pooling = functools.partial(
           tf.reduce_mean, axis=reduction_indices, keepdims=False)
     elif pooling == 'max':
       self.global_pooling = functools.partial(
           tf.reduce_max, axis=reduction_indices, keepdims=False)
     elif pooling == 'gem':
       logging.info('Adding GeMPooling layer with power %f', gem_power)
       self.global_pooling = functools.partial(
-          gem_pooling, axis=reduction_indices, power=gem_power)
+          pooling_layers.gem, axis=reduction_indices, power=gem_power)
     else:
       self.global_pooling = None
     if embedding_layer:
@@ -456,27 +457,4 @@ class ResNet50(tf.keras.Model):
             logging.info(weights)
           else:
             logging.info('Layer %s does not have inner layers.',
                          layer.name)
-
-
-def gem_pooling(feature_map, axis, power, threshold=1e-6):
-  """Performs GeM (Generalized Mean) pooling.
-
-  See https://arxiv.org/abs/1711.02512 for a reference.
-
-  Args:
-    feature_map: Tensor of shape [batch, height, width, channels] for
-      the "channels_last" format or [batch, channels, height, width] for the
-      "channels_first" format.
-    axis: Dimensions to reduce.
-    power: Float, GeM power parameter.
-    threshold: Optional float, threshold to use for activations.
-
-  Returns:
-    pooled_feature_map: Tensor of shape [batch, channels].
-  """
-  return tf.pow(
-      tf.reduce_mean(tf.pow(tf.maximum(feature_map, threshold), power),
-                     axis=axis,
-                     keepdims=False),
-      1.0 / power)
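The removed gem_pooling helper (the deleted lines above) and the new pooling.gem compute the same quantity, a power mean of thresholded activations, with `threshold` playing the role of `eps`. A quick equivalence sketch, with an illustrative random input:

import tensorflow as tf

from delf.python.pooling_layers import pooling

# Random [B=2, H=3, W=3, D=5] feature map (illustrative input).
feature_map = tf.random.uniform([2, 3, 3, 5])

# The removed helper computed mean(max(x, threshold)**p)**(1/p) ...
old_style = tf.pow(
    tf.reduce_mean(tf.pow(tf.maximum(feature_map, 1e-6), 3.0),
                   axis=[1, 2], keepdims=False), 1.0 / 3.0)
# ... which is exactly what pooling.gem does with power=3.0, eps=1e-6.
new_style = pooling.gem(feature_map, axis=[1, 2], power=3.0, eps=1e-6)
tf.debugging.assert_near(old_style, new_style)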