Unverified Commit c1b65fc1 authored by srihari-humbarwadi's avatar srihari-humbarwadi
Browse files

added docstrings and license header

parent 5803f451
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Contains definition for bilinear grid sampling and mask pasting layers."""
from typing import List
import tensorflow as tf
class BilinearGridSampler(tf.keras.layers.Layer):
def __init__(self, align_corners, **kwargs):
""" Bilinear Grid Sampling layer."""
def __init__(self, align_corners: bool = False, **kwargs):
"""Generates panoptic segmentation masks.
Args:
align_corners: A `bool` bool, if True, the centers of the 4 corner
pixels of the input and output tensors are aligned, preserving the
values at the corner pixels.
"""
super(BilinearGridSampler, self).__init__(**kwargs)
self.align_corners = align_corners
......@@ -25,7 +52,7 @@ class BilinearGridSampler(tf.keras.layers.Layer):
tf.greater_equal(y_coord, 0)),
tf.logical_and(
tf.less(x_coord, self._width),
tf.less(y_coord, self._width)))
tf.less(y_coord, self._height)))
def _get_pixel(self, features, x_coord, y_coord):
x_coord = tf.cast(x_coord, dtype=tf.int32)
......@@ -97,8 +124,18 @@ class BilinearGridSampler(tf.keras.layers.Layer):
class PasteMasks(tf.keras.layers.Layer):
def __init__(self, output_size, grid_sampler, **kwargs):
"""Layer to paste instance masks."""
def __init__(self, output_size: List[int],
grid_sampler, **kwargs):
"""Generates panoptic segmentation masks.
Args:
output_size: A `List` of integers that represent the height and width of
the output mask.
grid_sampler: A grid sampling layer. Currently only `BilinearGridSampler`
is supported.
"""
super(PasteMasks, self).__init__(**kwargs)
self._output_size = output_size
self._grid_sampler = grid_sampler
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment