Commit 32e4ca51 authored by qianyj's avatar qianyj
Browse files

Update code to v2.11.0

parents 9485aa1d 71060f67
{
"cells": [
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"cellView": "form",
"id": "xBH8CcrkV3IU"
},
"outputs": [],
"source": [
"#@title Licensed under the Apache License, Version 2.0 (the \"License\");\n",
"# you may not use this file except in compliance with the License.\n",
"# You may obtain a copy of the License at\n",
"#\n",
"# https://www.apache.org/licenses/LICENSE-2.0\n",
"#\n",
"# Unless required by applicable law or agreed to in writing, software\n",
"# distributed under the License is distributed on an \"AS IS\" BASIS,\n",
"# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
"# See the License for the specific language governing permissions and\n",
"# limitations under the License."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "9CzbXNRovpbc"
},
"source": [
"# Crown-of-Thorns Starfish Detection Pipeline"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "Lpb0yoNjiWhw"
},
"source": [
"\u003ctable class=\"tfo-notebook-buttons\" align=\"left\"\u003e\n",
" \u003ctd\u003e\n",
" \u003ca target=\"_blank\" href=\"https://colab.research.google.com/github/tensorflow/models/blob/master/official/projects/cots_detector/crown_of_thorns_starfish_detection_pipeline.ipynb?force_crab_mode=1\"\u003e\u003cimg src=\"https://www.tensorflow.org/images/colab_logo_32px.png\" /\u003eRun in Google Colab\u003c/a\u003e\n",
" \u003c/td\u003e\n",
" \u003ctd\u003e\n",
" \u003ca target=\"_blank\" href=\"https://github.com/tensorflow/models/blob/master/official/projects/cots_detector/crown_of_thorns_starfish_detection_pipeline.ipynb\"\u003e\u003cimg src=\"https://www.tensorflow.org/images/GitHub-Mark-32px.png\" /\u003eView on GitHub\u003c/a\u003e\n",
" \u003c/td\u003e\n",
"\u003c/table\u003e"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "GUQ1x137ysLD"
},
"source": [
"Coral reefs are some of the most diverse and important ecosystems in the world , however they face a number of rising threats that have resulted in massive global declines. In Australia, outbreaks of the coral-eating crown-of-thorns starfish (COTS) have been shown to cause major coral loss, with just 15 starfish in a hectare being able to strip a reef of 90% of its coral tissue. While COTS naturally exist in the Indo-Pacific, overfishing and excess run-off nutrients have led to massive outbreaks that are devastating already vulnerable coral communities.\n",
"\n",
"Controlling COTS populations is critical to promoting coral growth and resilience, so Google teamed up with Australia’s national science agency, [CSIRO](https://www.csiro.au/en/), to tackle this problem. We trained ML object detection models to help scale underwater surveys, enabling the monitoring and mapping out these harmful invertebrates with the ultimate goal of helping control teams to address and prioritize outbreaks."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "jDiIX2xawkJw"
},
"source": [
"## About this notebook\n",
"\n",
"This notebook tutorial shows how to detect COTS using a pre-trained COTS detector implemented in TensorFlow. On top of just running the model on each frame of the video, the tracking code in this notebook aligns detections from frame to frame creating a consistent track for each COTS. Each track is given an id and frame count. Here is an example image from a video of a reef showing labeled COTS starfish.\n",
"\n",
"\u003cimg src=\"https://storage.googleapis.com/download.tensorflow.org/data/cots_detection/COTS_detected_sample.png\"\u003e"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "YxCF1t-Skag8"
},
"source": [
"It is recommended to enable GPU to accelerate the inference. On CPU, this runs for about 40 minutes, but on GPU it takes only 10 minutes. (In Colab it should already be set to GPU in the Runtime menu: *Runtime \u003e Change runtime type \u003e Hardware accelerator \u003e select \"GPU\"*)."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "a4R2T97u442o"
},
"source": [
"## Setup \n",
"\n",
"Install all needed packages."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "5Gs7XvCGlwlj"
},
"outputs": [],
"source": [
"# remove the existing datascience package to avoid package conflicts in the colab environment\n",
"!pip3 uninstall -y datascience\n",
"!pip3 install -q opencv-python\n",
"!pip3 install PILLOW"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "w-UQ87240x5R"
},
"outputs": [],
"source": [
"# Imports\n",
"import base64\n",
"import copy\n",
"import dataclasses\n",
"import glob\n",
"import logging\n",
"import mimetypes\n",
"import os\n",
"import pathlib\n",
"import subprocess\n",
"import time\n",
"import textwrap\n",
"from typing import Dict, Iterable, List, Optional, Tuple\n",
"\n",
"from absl import logging as absl_logging\n",
"from IPython import display\n",
"import cv2\n",
"import matplotlib.pyplot as plt\n",
"import numpy as np\n",
"import PIL.Image\n",
"import tensorflow as tf\n",
"from tqdm import tqdm"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "gsSclJg4sJbX"
},
"source": [
"Define all needed variables."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "iKMCvnZEXBBT"
},
"outputs": [],
"source": [
"model_name = \"cots_1080_v1\" #@param [\"cots_1080_v1\", \"cots_720_v1\"]\n",
"test_sequence_name = \"test3\" #@param [\"test1\", \"test2\", \"test3\"]"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "ORLJSdLq4-gd"
},
"outputs": [],
"source": [
"cots_model = f\"https://storage.googleapis.com/download.tensorflow.org/models/cots_detection/{model_name}.zip\"\n",
"\n",
"# Alternatively, this dataset can be downloaded through CSIRO's Data Access Portal at https://data.csiro.au/collection/csiro:54830v2\n",
"sample_data_link = f\"https://storage.googleapis.com/download.tensorflow.org/data/cots_detection/sample_images.zip\"\n",
"\n",
"preview_video_path = \"preview.mp4\"\n",
"detection_small_video_path = \"COTS_detection.mp4\"\n",
"detection_csv_path = \"detections.csv\""
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "FNwP3s-5xgaF"
},
"source": [
"You also need to retrieve the sample data. This sample data is made up of a series of chronological images."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "DF_c_ZMXdPRN"
},
"outputs": [],
"source": [
"sample_data_path = tf.keras.utils.get_file(origin=sample_data_link)\n",
"# Unzip data\n",
"!mkdir sample_images\n",
"!unzip -o -q {sample_data_path} -d sample_images"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "Ghf-4E5-ZiJn"
},
"source": [
"Convert the images to a video file:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "kCdWsbO1afIJ"
},
"outputs": [],
"source": [
"tmp_video_path = \"tmp_preview.mp4\"\n",
"\n",
"filenames = sorted(glob.glob(f\"sample_images/{test_sequence_name}/*.jpg\"))\n",
"img = cv2.imread(filenames[0])\n",
"height, width, layers = img.shape\n",
"size = (width, height)\n",
"\n",
"video_writer = cv2.VideoWriter(\n",
" filename=tmp_video_path,\n",
" fourcc=cv2.VideoWriter_fourcc(*\"MP4V\"), \n",
" fps=15, \n",
" frameSize=size)\n",
" \n",
"for filename in tqdm(filenames):\n",
" img = cv2.imread(filename)\n",
" video_writer.write(img)\n",
"cv2.destroyAllWindows()\n",
"video_writer.release()"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "cHsKpPyviWmF"
},
"source": [
"Re-encode the video, and reduce its size (Colab crashes if you try to embed the full size video)."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "_li0qe-gh1iT"
},
"outputs": [],
"source": [
"subprocess.check_call([\n",
" \"ffmpeg\", \"-y\", \"-i\", tmp_video_path,\n",
" \"-vf\",\"scale=800:-1\",\n",
" \"-crf\", \"18\",\n",
" \"-preset\", \"veryfast\",\n",
" \"-vcodec\", \"libx264\", preview_video_path])"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "2ItoiHyYQGya"
},
"source": [
"The images you downloaded are frames of a movie showing a top view of a coral reef with crown-of-thorns starfish. Use the `base64` data-URL trick to embed the video in this notebook:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "u0fqXQUzdZCu"
},
"outputs": [],
"source": [
"def embed_video_file(path: os.PathLike) -\u003e display.HTML:\n",
" \"\"\"Embeds a file in the notebook as an html tag with a data-url.\"\"\"\n",
" path = pathlib.Path(path)\n",
" mime, unused_encoding = mimetypes.guess_type(str(path))\n",
" data = path.read_bytes()\n",
"\n",
" b64 = base64.b64encode(data).decode()\n",
" return display.HTML(\n",
" textwrap.dedent(\"\"\"\n",
" \u003cvideo width=\"640\" height=\"480\" controls\u003e\n",
" \u003csource src=\"data:{mime};base64,{b64}\" type=\"{mime}\"\u003e\n",
" Your browser does not support the video tag.\n",
" \u003c/video\u003e\n",
" \"\"\").format(mime=mime, b64=b64))\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "SiOsbr8xePkg"
},
"outputs": [],
"source": [
"embed_video_file(preview_video_path)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "9Z0DTbWrZMZ-"
},
"source": [
"Can you se them? there are lots. The goal of the model is to put boxes around all of the starfish. Each starfish will get its own ID, and that ID will be stable as the camera passes over it."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "d0iALUwM0g2p"
},
"source": [
"## Load the model"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "fVq6vNBTxM62"
},
"source": [
"Download the trained COTS detection model that matches your preferences from earlier."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "No5jRA1TxXj0"
},
"outputs": [],
"source": [
"model_path = tf.keras.utils.get_file(origin=cots_model)\n",
"# Unzip model\n",
"!mkdir {model_name}\n",
"!unzip -o -q {model_path} -d {model_name}"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "ezyuSHK5ap__"
},
"source": [
"Load trained model from disk and create the inference function `model_fn()`. This might take a little while."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "HXQnNjwl8Beu"
},
"outputs": [],
"source": [
"absl_logging.set_verbosity(absl_logging.ERROR)\n",
"\n",
"tf.config.optimizer.set_experimental_options({'auto_mixed_precision': True})\n",
"tf.config.optimizer.set_jit(True)\n",
"\n",
"model_fn = tf.saved_model.load(model_name).signatures['serving_default']"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "OvLuznhUa7uG"
},
"source": [
"Here's one test image; how many COTS can you see?"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "XmQF_2L_a7Hu"
},
"outputs": [],
"source": [
"example_frame_number = 52\n",
"image = tf.io.read_file(filenames[example_frame_number])\n",
"image = tf.io.decode_jpeg(image)\n",
"\n",
"# Caution PIL and tf use \"RGB\" color order, while cv2 uses \"BGR\".\n",
"PIL.Image.fromarray(image.numpy())"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "KSOf4V8WhTHF"
},
"source": [
"## Raw model outputs\n",
"\n",
"Try running the model on the image. The model expects a batch of images so add an outer `batch` dimension before calling the model.\n",
"\n",
"Note: The model only runs correctly with a batch size of 1.\n",
"\n",
"The result is a dictionary with a number of fields. For all fields the first dimension of the shape is the `batch` dimension, "
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "iqLHo8h0c2pW"
},
"outputs": [],
"source": [
"image_batch = image[tf.newaxis, ...]\n",
"result = model_fn(image_batch)\n",
"\n",
"print(f\"{'image_batch':20s}- shape: {image_batch.shape}\")\n",
"\n",
"for key, value in result.items():\n",
" print(f\"{key:20s}- shape: {value.shape}\")"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "0xuNoKLCjyDz"
},
"source": [
"The `num_detections` field gives the number of valid detections, but this is always 100. There are always 100 locations that _could_ be a COTS."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "nGCDZJQvkIOL"
},
"outputs": [],
"source": [
"print('\\nnum_detections: ', result['num_detections'].numpy())"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "cSd7JJYqkPz7"
},
"source": [
"Similarly the `detection_classes` field is always `0`, since the model only detects 1 class: COTS."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "JoY8bJrfkcuS"
},
"outputs": [],
"source": [
"print('detection_classes: \\n', result['detection_classes'].numpy())"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "X2nVLSOokyog"
},
"source": [
"What actually matters here is the detection scores, indicating the quality of each detection: "
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "iepEgCc2jsRD"
},
"outputs": [],
"source": [
"result['detection_scores'].numpy()"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "Fn2B0nbplAFy"
},
"source": [
"You need to choose a threshold that determines what counts as a good detection. This frame has a few good detections:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "a30Uyc0WlK2a"
},
"outputs": [],
"source": [
"good_detections = result['detection_scores'] \u003e 0.4\n",
"good_detections.numpy()"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "Y_xrbQiAlWrK"
},
"source": [
"## Bounding boxes and detections\n",
"\n",
"Build a class to handle the detection boxes:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "S5inzqu-4JhT"
},
"outputs": [],
"source": [
"@dataclasses.dataclass(frozen=True)\n",
"class BBox:\n",
" x0: float\n",
" y0: float\n",
" x1: float\n",
" y1: float\n",
"\n",
" def replace(self, **kwargs):\n",
" d = self.__dict__.copy()\n",
" d.update(kwargs)\n",
" return type(self)(**d)\n",
"\n",
" @property\n",
" def center(self)-\u003e Tuple[float, float]:\n",
" return ((self.x0+self.x1)/2, (self.y0+self.y1)/2)\n",
" \n",
" @property\n",
" def width(self) -\u003e float:\n",
" return self.x1 - self.x0\n",
"\n",
" @property\n",
" def height(self) -\u003e float:\n",
" return self.y1 - self.y0\n",
"\n",
" @property\n",
" def area(self)-\u003e float:\n",
" return (self.x1 - self.x0 + 1) * (self.y1 - self.y0 + 1)\n",
" \n",
" def intersection(self, other)-\u003e Optional['BBox']:\n",
" x0 = max(self.x0, other.x0)\n",
" y0 = max(self.y0, other.y0)\n",
" x1 = min(self.x1, other.x1)\n",
" y1 = min(self.y1, other.y1)\n",
" if x0 \u003e x1 or y0 \u003e y1:\n",
" return None\n",
" return BBox(x0, y0, x1, y1)\n",
"\n",
" def iou(self, other):\n",
" intersection = self.intersection(other)\n",
" if intersection is None:\n",
" return 0\n",
" \n",
" ia = intersection.area\n",
"\n",
" return ia/(self.area + other.area - ia)\n",
" \n",
" def draw(self, image, label=None, color=(0, 140, 255)):\n",
" image = np.asarray(image)\n",
" cv2.rectangle(image, \n",
" (int(self.x0), int(self.y0)),\n",
" (int(self.x1), int(self.y1)),\n",
" color,\n",
" thickness=2)\n",
" if label is not None:\n",
" cv2.putText(image, str(label), \n",
" (int(self.x0), int(self.y0-10)),\n",
" cv2.FONT_HERSHEY_SIMPLEX,\n",
" 0.9, color, thickness=2)\n",
" return image"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "2izYMR9Q6Dn0"
},
"source": [
"And a class to represent a `Detection`, with a method to create a list of detections from the model's output:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "tybwY3eaY803"
},
"outputs": [],
"source": [
"@dataclasses.dataclass(frozen=True)\n",
"class Detection:\n",
" \"\"\"Detection dataclass.\"\"\"\n",
" class_id: int\n",
" score: float\n",
" bbox: BBox\n",
" threshold:float = 0.4\n",
"\n",
" def replace(self, **kwargs):\n",
" d = self.__dict__.copy()\n",
" d.update(kwargs)\n",
" return type(self)(**d)\n",
"\n",
" @classmethod\n",
" def process_model_output(\n",
" cls, image, detections: Dict[str, tf.Tensor]\n",
" ) -\u003e Iterable['Detection']:\n",
" \n",
" # The model only works on a batch size of 1.\n",
" detection_boxes = detections['detection_boxes'].numpy()[0]\n",
" detection_classes = detections['detection_classes'].numpy()[0].astype(np.int32)\n",
" detection_scores = detections['detection_scores'].numpy()[0]\n",
"\n",
" img_h, img_w = image.shape[0:2]\n",
"\n",
" valid_indices = detection_scores \u003e= cls.threshold\n",
" classes = detection_classes[valid_indices]\n",
" scores = detection_scores[valid_indices]\n",
" boxes = detection_boxes[valid_indices, :]\n",
" detections = []\n",
"\n",
" for class_id, score, box in zip(classes, scores, boxes):\n",
" detections.append(\n",
" Detection(\n",
" class_id=class_id,\n",
" score=score,\n",
" bbox=BBox(\n",
" x0=box[1] * img_w,\n",
" y0=box[0] * img_h,\n",
" x1=box[3] * img_w,\n",
" y1=box[2] * img_h,)))\n",
"\n",
" return detections"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "QRZ9Q5meHl84"
},
"source": [
"## Preview some detections\n",
"\n",
"Now you can preview the model's output:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "Px7AoFCn-psx"
},
"outputs": [],
"source": [
"detections = Detection.process_model_output(image, result)\n",
"\n",
"for n, det in enumerate(detections):\n",
" det.bbox.draw(image, label=n+1, color=(255, 140, 0))\n",
"\n",
"PIL.Image.fromarray(image.numpy())"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "B1q_n1xJLm60"
},
"source": [
"That works well for one frame, but to count the number of COTS in a video you'll need to track the detections from frame to frame. The raw detection indices are not stable, they're just sorted by the detection score. Below both sets of detections are overlaid on the second image with the first frame's detections in white and the second frame's in orange, the indices are not aligned. The positions are shifted because of camera motion between the two frames:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "PLtxJFPuLma0"
},
"outputs": [],
"source": [
"image2 = tf.io.read_file(filenames[example_frame_number+5]) # five frames later\n",
"image2 = tf.io.decode_jpeg(image2)\n",
"result2 = model_fn(image2[tf.newaxis, ...])\n",
"detections2 = Detection.process_model_output(image2, result2)\n",
"\n",
"for n, det in enumerate(detections):\n",
" det.bbox.draw(image2, label=n+1, color=(255, 255, 255))\n",
"\n",
"for n, det in enumerate(detections2):\n",
" det.bbox.draw(image2, label=n+1, color=(255, 140, 0))\n",
"\n",
"PIL.Image.fromarray(image2.numpy())"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "CoRxLon5MZ35"
},
"source": [
"## Use optical flow to align detections\n",
"\n",
"The two sets of bounding boxes above don't line up because of camera movement. \n",
"To see in more detail how tracks are aligned, initialize the tracker with the first image, and then run the optical flow step, `propagate_tracks`. "
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "wb_nkcPJJx2t"
},
"outputs": [],
"source": [
"def default_of_params():\n",
" its=20\n",
" eps=0.03\n",
" return {\n",
" 'winSize': (64,64),\n",
" 'maxLevel': 3,\n",
" 'criteria': (cv2.TermCriteria_COUNT + cv2.TermCriteria_EPS, its, eps)\n",
" }"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "mHVPymG8F2ke"
},
"outputs": [],
"source": [
"def propagate_detections(detections, image1, image2, of_params=None):\n",
" if of_params is None:\n",
" of_params = default_of_params()\n",
"\n",
" bboxes = [det.bbox for det in detections]\n",
" centers = np.float32([[bbox.center for bbox in bboxes]])\n",
" widths = np.float32([[bbox.width for bbox in bboxes]])\n",
" heights = np.float32([[bbox.height for bbox in bboxes]])\n",
"\n",
"\n",
" new_centers, status, error = cv2.calcOpticalFlowPyrLK(\n",
" image1, image2, centers, None, **of_params)\n",
"\n",
" x0s = new_centers[...,0] - widths/2\n",
" x1s = new_centers[...,0] + widths/2\n",
" y0s = new_centers[...,1] - heights/2\n",
" y1s = new_centers[...,1] + heights/2\n",
"\n",
" updated_detections = []\n",
" for i, det in enumerate(detections):\n",
" det = det.replace(\n",
" bbox = BBox(x0s[0,i], y0s[0,i], x1s[0,i], y1s[0,i]))\n",
" updated_detections.append(det)\n",
" return updated_detections"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "dCjgvoZnOcBu"
},
"source": [
"Now keep the white boxes for the initial detections, and the orange boxes for the new set of detections. But add the optical-flow propagated tracks in green. You can see that by using optical-flow to propagate the old detections to the new frame the alignment is quite good. It's this alignment between the old and new detections (between the green and orange boxes) that allows the tracker to make a persistent track for each COTS. "
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "aeTny8YnHwTw"
},
"outputs": [],
"source": [
"image = tf.io.read_file(filenames[example_frame_number])\n",
"image = tf.io.decode_jpeg(image).numpy()\n",
"image_gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)\n",
"\n",
"image2 = tf.io.read_file(filenames[example_frame_number+5]) # five frames later\n",
"image2 = tf.io.decode_jpeg(image2).numpy()\n",
"image2_gray = cv2.cvtColor(image2, cv2.COLOR_BGR2GRAY)\n",
"\n",
"updated_detections = propagate_detections(detections, image_gray, image2_gray)\n",
"\n",
"\n",
"for det in detections:\n",
" det.bbox.draw(image2, color=(255, 255, 255))\n",
"\n",
"for det in updated_detections:\n",
" det.bbox.draw(image2, color=(0, 255, 0))\n",
"\n",
"for det in detections2:\n",
" det.bbox.draw(image2, color=(255, 140, 0))\n",
"\n",
"PIL.Image.fromarray(image2)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "jbZ-7ICCENWG"
},
"source": [
"## Define **OpticalFlowTracker** class\n",
"\n",
"These help track the movement of each COTS object across the video frames.\n",
"\n",
"The tracker collects related detections into `Track` objects. \n",
"\n",
"The class's init is defined below, it's methods are defined in the following cells.\n",
"\n",
"The `__init__` method just initializes the track counter (`track_id`), and sets some default values for the tracking and optical flow configurations. "
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "3j2Ka1uGEoz4"
},
"outputs": [],
"source": [
"class OpticalFlowTracker:\n",
" \"\"\"Optical flow tracker.\"\"\"\n",
"\n",
" @classmethod\n",
" def add_method(cls, fun):\n",
" \"\"\"Attach a new method to the class.\"\"\"\n",
" setattr(cls, fun.__name__, fun)\n",
"\n",
"\n",
" def __init__(self, tid=1, ft=3.0, iou=0.5, tt=2.0, bb=32, of_params=None):\n",
" # Bookkeeping for the tracks.\n",
" # The running track count, incremented for each new track.\n",
" self.track_id = tid\n",
" self.tracks = []\n",
" self.prev_image = None\n",
" self.prev_time = None\n",
"\n",
" # Configuration for the track cleanup logic.\n",
" # How long to apply optical flow tracking without getting positive \n",
" # detections (sec).\n",
" self.track_flow_time = ft * 1000\n",
" # Required IoU overlap to link a detection to a track.\n",
" self.overlap_threshold = iou\n",
" # Used to detect if detector needs to be reset.\n",
" self.time_threshold = tt * 1000\n",
" self.border = bb\n",
"\n",
" if of_params is None:\n",
" of_params = default_of_params()\n",
" self.of_params = of_params\n"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "yBLSv0Fi_JJD"
},
"source": [
"Internally the tracker will use small `Track` and `Tracklet` classes to organize the data. The `Tracklet` class is just a `Detection` with a timestamp, while a `Track` is a track ID, the most recent detection and a list of `Tracklet` objects forming the history of the track."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "gCQFfAkaY_WN"
},
"outputs": [],
"source": [
"@dataclasses.dataclass(frozen=True)\n",
"class Tracklet:\n",
" timestamp:float\n",
" detection:Detection\n",
"\n",
" def replace(self, **kwargs):\n",
" d = self.__dict__.copy()\n",
" d.update(kwargs)\n",
" return type(self)(**d)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "7qVW1a_YZBgL"
},
"outputs": [],
"source": [
"@dataclasses.dataclass(frozen=True)\n",
"class Track:\n",
" \"\"\"Tracker entries.\"\"\"\n",
" id:int\n",
" det: Detection\n",
" linked_dets:List[Tracklet] = dataclasses.field(default_factory=list)\n",
"\n",
" def replace(self, **kwargs):\n",
" d = self.__dict__.copy()\n",
" d.update(kwargs)\n",
" return type(self)(**d)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "Ntl_4oUp_1nD"
},
"source": [
"The tracker keeps a list of active `Track` objects.\n",
"\n",
"The main `update` method takes an image, along with the list of detections and the timestamp for that image. On each frame step it performs the following sub-tasks:\n",
"\n",
"* The tracker uses optical flow to calculate where each `Track` expects to see a new `Detection`.\n",
"* The tracker matches up the actual detections for the frame to the expected detections for each Track.\n",
"* If a detection doesn't get matched to an existing track, a new track is created for the detection.\n",
"* If a track stops getting assigned new detections, it is eventually deactivated. "
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "koZ0mjFTpiTv"
},
"outputs": [],
"source": [
"@OpticalFlowTracker.add_method\n",
"def update(self, image_bgr, detections, timestamp):\n",
" start = time.time()\n",
"\n",
" image = cv2.cvtColor(image_bgr, cv2.COLOR_BGR2GRAY)\n",
"\n",
" # Remove dead tracks.\n",
" self.tracks = self.cleanup_tracks(image, timestamp)\n",
"\n",
" # Run optical flow to update existing tracks.\n",
" if self.prev_time is not None:\n",
" self.tracks = self.propagate_tracks(image)\n",
"\n",
" # Update the track list based on the new detections\n",
" self.apply_detections_to_tracks(image, detections, timestamp)\n",
"\n",
" self.prev_image = image\n",
" self.prev_time = timestamp\n",
"\n",
" return self.tracks"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "U-6__zF2CHFS"
},
"source": [
"The `cleanup_tracks` method clears tracks that are too old or are too close to the edge of the image."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "HQBj8GihjF3-"
},
"outputs": [],
"source": [
"@OpticalFlowTracker.add_method\n",
"def cleanup_tracks(self, image, timestamp) -\u003e List[Track]:\n",
" image_w = image.shape[1]\n",
" image_h = image.shape[0]\n",
"\n",
" # Assume tracker is invalid if too much time has passed!\n",
" if (self.prev_time is not None and\n",
" timestamp - self.prev_time \u003e self.time_threshold):\n",
" logging.info(\n",
" 'Too much time since last update, resetting tracker.')\n",
" return []\n",
"\n",
" # Remove tracks which are:\n",
" # - Touching the image edge.\n",
" # - Have existed for a long time without linking a real detection.\n",
" active_tracks = []\n",
" for track in self.tracks:\n",
" bbox = track.det.bbox\n",
" if (bbox.x0 \u003c self.border or bbox.y0 \u003c self.border or\n",
" bbox.x1 \u003e= (image_w - self.border) or\n",
" bbox.y1 \u003e= (image_h - self.border)):\n",
" logging.info(f'Removing track {track.id} because it\\'s near the border')\n",
" continue\n",
"\n",
" time_since_last_detection = timestamp - track.linked_dets[-1].timestamp\n",
" if (time_since_last_detection \u003e self.track_flow_time):\n",
" logging.info(f'Removing track {track.id} because it\\'s too old '\n",
" f'({time_since_last_detection:.02f}s)')\n",
" continue\n",
"\n",
" active_tracks.append(track)\n",
"\n",
" return active_tracks"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "DVzNcESxC6vY"
},
"source": [
"The `propagate_tracks` method uses optical flow to update each track's bounding box's position to predict their location in the new image: "
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "0GycdAflCs6v"
},
"outputs": [],
"source": [
"@OpticalFlowTracker.add_method\n",
"def propagate_tracks(self, image):\n",
" if not self.tracks:\n",
" return self.tracks[:]\n",
"\n",
" detections = [track.det for track in self.tracks]\n",
" detections = propagate_detections(detections, self.prev_image, image, self.of_params)\n",
"\n",
" return [track.replace(det=det) \n",
" for track, det in zip(self.tracks, detections)]\n"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "uLbVeetwD0ph"
},
"source": [
"The `apply_detections_to_tracks` method compares each detection to the updated bounding box for each track. The detection is added to the track that matches best, if the match is better than the `overlap_threshold`. If no track is better than the threshold, the detection is used to create a new track. \n",
"\n",
"If a track has no new detection assigned to it the predicted detection is used."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "j6pfRhDRlApe"
},
"outputs": [],
"source": [
"@OpticalFlowTracker.add_method\n",
"def apply_detections_to_tracks(self, image, detections, timestamp):\n",
" image_w = image.shape[1]\n",
" image_h = image.shape[0]\n",
"\n",
" # Insert new detections.\n",
" detected_obj_track_ids = set()\n",
"\n",
" for detection in detections:\n",
" bbox = detection.bbox\n",
" if (bbox.x0 \u003c self.border or bbox.y0 \u003c self.border or\n",
" bbox.x1 \u003e= image_w - self.border or\n",
" bbox.y1 \u003e= image_h - self.border):\n",
" logging.debug('Skipping detection because it\\'s close to the border.')\n",
" continue\n",
"\n",
" # See if detection can be linked to an existing track.\n",
" linked = False\n",
" overlap_index = 0\n",
" overlap_max = -1000\n",
" for track_index, track in enumerate(self.tracks):\n",
" logging.debug('Testing track %d', track_index)\n",
" if track.det.class_id != detection.class_id:\n",
" continue\n",
" overlap = detection.bbox.iou(track.det.bbox)\n",
" if overlap \u003e overlap_max:\n",
" overlap_index = track_index\n",
" overlap_max = overlap\n",
"\n",
" # Link to existing track with maximal IoU.\n",
" if overlap_max \u003e self.overlap_threshold:\n",
" track = self.tracks[overlap_index]\n",
" self.tracks[overlap_index] = track.replace(det=detection)\n",
" track.linked_dets.append(Tracklet(timestamp, detection))\n",
" detected_obj_track_ids.add(track.id)\n",
" linked = True\n",
"\n",
" if not linked:\n",
" logging.info(f'Creating new track with ID {self.track_id}')\n",
" new_track = Track(self.track_id, detection)\n",
" new_track.linked_dets.append(Tracklet(timestamp, detection))\n",
" detected_obj_track_ids.add(self.track_id)\n",
" self.tracks.append(new_track)\n",
" self.track_id += 1\n",
"\n",
" for track in self.tracks:\n",
" # If the detector does not find the obj but estimated in the tracker, \n",
" # add the estimated one to that tracker's linked_dets\n",
" if track.id not in detected_obj_track_ids:\n",
" track.linked_dets.append(Tracklet(timestamp, track.det))"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "gY0AH-KUHPlC"
},
"source": [
"## Test run the tracker\n",
"\n",
"So reload the test images, and run the detections to test out the tracker.\n",
"\n",
"On the first frame it creates and returns one track per detection:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "7Ekkj_XFGdfq"
},
"outputs": [],
"source": [
"example_frame_number = 52\n",
"image = tf.io.read_file(filenames[example_frame_number])\n",
"image = tf.io.decode_jpeg(image)\n",
"result = model_fn(image[tf.newaxis, ...])\n",
"detections = Detection.process_model_output(image, result)\n",
"\n",
"tracker = OpticalFlowTracker()\n",
"tracks = tracker.update(image.numpy(), detections, timestamp = 0)\n",
"\n",
"print(f'detections : {len(detections)}') \n",
"print(f'tracks : {len(tracks)}')"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "WovDYdNMII-n"
},
"source": [
"On the second frame many of the detections get assigned to existing tracks:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "7iFEKwgMGi5n"
},
"outputs": [],
"source": [
"image2 = tf.io.read_file(filenames[example_frame_number+5]) # five frames later\n",
"image2 = tf.io.decode_jpeg(image2)\n",
"result2 = model_fn(image2[tf.newaxis, ...])\n",
"detections2 = Detection.process_model_output(image2, result2)\n",
"\n",
"new_tracks = tracker.update(image2.numpy(), detections2, timestamp = 1000)\n",
"\n",
"print(f'detections : {len(detections2)}') \n",
"print(f'tracks : {len(new_tracks)}')"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "dbkedwiVrxnQ"
},
"source": [
"Now the track IDs should be consistent:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "QexJR5gerw6q"
},
"outputs": [],
"source": [
"test_img = image2.numpy()\n",
"for n,track in enumerate(tracks):\n",
" track.det.bbox.draw(test_img, label=n, color=(255, 255, 255))\n",
"\n",
"for n,track in enumerate(new_tracks):\n",
" track.det.bbox.draw(test_img, label=n, color=(255, 140, 0))\n",
"\n",
"PIL.Image.fromarray(test_img)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "OW5gGixy1osE"
},
"source": [
"## Perform the COTS detection inference and tracking."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "f21596933d08"
},
"source": [
"The main tracking loop will perform the following: \n",
"\n",
"1. Load the images in order.\n",
"2. Run the model on the image.\n",
"3. Update the tracker with the new images and detections.\n",
"4. Keep information about each track (id, current index and length) analysis or display. \n",
"\n",
"The `TrackAnnotation` class, below, will collect the data about each track:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "lESJE0qXxubm"
},
"outputs": [],
"source": [
"@dataclasses.dataclass(frozen=True)\n",
"class TrackAnnotation:\n",
" det: Detection\n",
" seq_id: int\n",
" seq_idx: int\n",
" seq_length: Optional[int] = None\n",
"\n",
" def replace(self, **kwargs):\n",
" d = self.__dict__.copy()\n",
" d.update(kwargs)\n",
" return type(self)(**d)\n",
"\n",
" def annotation_str(self):\n",
" return f\"{self.seq_id} ({self.seq_idx}/{self.seq_length})\"\n"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "3863fb28cd34"
},
"source": [
"The `parse_image` function, below, will take `(index, filename)` pairs load the images as tensors and return `(timestamp_ms, filename, image)` triples, assuming 30fps"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "Dn7efhr0GBGz"
},
"outputs": [],
"source": [
"# Read a jpg image and decode it to a uint8 tf tensor.\n",
"def parse_image(index, filename):\n",
" image = tf.io.read_file(filename)\n",
" image = tf.io.decode_jpeg(image)\n",
" timestamp_ms = 1000*index/30 # assuming 30fps\n",
" return (timestamp_ms, filename, image)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "8f878e4b0852"
},
"source": [
"Here is the main tracker loop. Note that initially the saved `TrackAnnotations` don't contain the track lengths. The lengths are collected in the `track_length_for_id` dict."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "cqN8RGBgVbr4"
},
"outputs": [],
"source": [
"# Create a tracker object\n",
"tracker = OpticalFlowTracker(tid=1)\n",
"# Record tracking responses from the tracker\n",
"detection_result = []\n",
"# Record the length of each tracking sequence\n",
"track_length_for_id = {}\n",
"\n",
"# Create a data loader\n",
"file_list = sorted(glob.glob(f\"sample_images/{test_sequence_name}/*.jpg\"))\n",
"list_ds = tf.data.Dataset.from_tensor_slices(file_list).enumerate()\n",
"images_ds = list_ds.map(parse_image)\n",
"\n",
"# Traverse the dataset with batch size = 1, you cannot change the batch size\n",
"for timestamp_ms, file_path, images in tqdm(images_ds.batch(1, drop_remainder=True)):\n",
" # get detection result\n",
" detections = Detection.process_model_output(images[0], model_fn(images))\n",
"\n",
" # Feed detection results and the corresponding timestamp to the tracker, and then get tracker response\n",
" tracks = tracker.update(images[0].numpy(), detections, timestamp_ms[0])\n",
" annotations = []\n",
" for track in tracks:\n",
" anno = TrackAnnotation(\n",
" det=track.det,\n",
" seq_id = track.id,\n",
" seq_idx = len(track.linked_dets)\n",
" )\n",
" annotations.append(anno)\n",
" track_length_for_id[track.id] = len(track.linked_dets)\n",
" \n",
" detection_result.append((file_path.numpy()[0].decode(), annotations))"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "29306d7f32df"
},
"source": [
"Once the tracking loop has completed you can update the track length (`seq_length`) for each annotation from the `track_length_for_id` dict:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "oPSfnQ1o04Rx"
},
"outputs": [],
"source": [
"def update_annotation_lengths(detection_result, track_length_for_id):\n",
" new_result = []\n",
" for file_path, annotations in detection_result:\n",
" new_annotations = []\n",
" for anno in annotations:\n",
" anno = anno.replace(seq_length=track_length_for_id[anno.seq_id])\n",
" new_annotations.append(anno)\n",
" new_result.append((file_path, new_annotations))\n",
" return new_result"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "zda914lv1o_v"
},
"outputs": [],
"source": [
"detection_result = update_annotation_lengths(detection_result, track_length_for_id)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "QkpmYRyFAMlM"
},
"source": [
"## Output the detection results and play the result video\n",
"\n",
"Once the inference is done, we draw the bounding boxes and track information onto each frame's image. Finally, we combine all frames into a video for visualisation."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "gWMJG7g95MGk"
},
"outputs": [],
"source": [
"detection_full_video_path = \"COTS_detection_full_size.mp4\"\n",
"detect_video_writer = cv2.VideoWriter(\n",
" filename=detection_full_video_path,\n",
" fourcc=cv2.VideoWriter_fourcc(*\"MP4V\"), \n",
" fps=15, \n",
" frameSize=size)\n",
"\n",
"for file_path, annotations in tqdm(detection_result):\n",
" image = cv2.imread(file_path)\n",
" for anno in annotations:\n",
" anno.det.bbox.draw(image, label=anno.annotation_str(), color=(0, 140, 255))\n",
" detect_video_writer.write(image)\n",
"cv2.destroyAllWindows()\n",
"\n",
"detect_video_writer.release()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "9s1myz67jcV8"
},
"outputs": [],
"source": [
"subprocess.check_call([\n",
" \"ffmpeg\",\"-y\", \"-i\", detection_full_video_path,\n",
" \"-vf\",\"scale=800:-1\",\n",
" \"-crf\", \"18\",\n",
" \"-preset\", \"veryfast\",\n",
" \"-vcodec\", \"libx264\", detection_small_video_path])"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "wsK5cvX5jkL7"
},
"outputs": [],
"source": [
"embed_video_file(detection_small_video_path)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "n1oOgMR2zzIl"
},
"source": [
"The output video is now saved as movie at `detection_full_video_path`. You can download your video by uncommenting the following code."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "tyHucK8lbGXk"
},
"outputs": [],
"source": [
"#try:\n",
"# from google.colab import files\n",
"# files.download(detection_full_video_path)\n",
"#except ImportError:\n",
"# pass"
]
}
],
"metadata": {
"accelerator": "GPU",
"colab": {
"collapsed_sections": [],
"name": "crown_of_thorns_starfish_detection_pipeline.ipynb",
"toc_visible": true
},
"kernelspec": {
"display_name": "Python 3",
"name": "python3"
}
},
"nbformat": 4,
"nbformat_minor": 0
}
# Mask R-CNN with deep mask heads
This project brings insights from the DeepMAC model into the Mask-RCNN
architecture. Please see the paper
[The surprising impact of mask-head architecture on novel class segmentation](https://arxiv.org/abs/2104.00613)
for more details.
## Code structure
* This folder contains forks of a few Mask R-CNN files and repurposes them to
support deep mask heads.
* To see the benefits of using deep mask heads, it is important to train the
mask head with only groundtruth boxes. This is configured via the
`task.model.use_gt_boxes_for_masks` flag.
* Architecture of the mask head can be changed via the config value
`task.model.mask_head.convnet_variant`. Supported values are `"default"`,
`"hourglass20"`, `"hourglass52"`, and `"hourglass100"`.
* The flag `task.model.mask_head.class_agnostic` trains the model in class
agnostic mode and `task.allowed_mask_class_ids` controls which classes are
allowed to have masks during training.
* Majority of experiments and ablations from the paper are perfomed with the
[DeepMAC model](https://github.com/tensorflow/models/blob/master/research/object_detection/g3doc/deepmac.md)
in the Object Detection API code base.
## Prerequisites
### Prepare dataset
Use [create_coco_tf_record.py](https://github.com/tensorflow/models/blob/master/official/vision/data/create_coco_tf_record.py) to create
the COCO dataset. The data needs to be store in a
[Google cloud storage bucket](https://cloud.google.com/storage/docs/creating-buckets)
so that it can be accessed by the TPU.
### Start a TPU v3-32 instance
See [TPU Quickstart](https://cloud.google.com/tpu/docs/quickstart) for
instructions. An example command would look like:
```shell
ctpu up --name <tpu-name> --zone <zone> --tpu-size=v3-32 --tf-version nightly
```
This model requires TF version `>= 2.5`. Currently, that is only available via a
`nightly` build on Cloud.
### Install requirements
SSH into the TPU host with `gcloud compute ssh <tpu-name>` and execute the
following.
```shell
$ git clone https://github.com/tensorflow/models.git
$ cd models
$ pip3 install -r official/requirements.txt
```
## Training Models
The configurations can be found in the `configs/experiments` directory. You can
launch a training job by executing.
```shell
$ export CONFIG=./official/projects/deepmac_maskrcnn/configs/experiments/deep_mask_head_rcnn_voc_r50.yaml
$ export MODEL_DIR="gs://<path-for-checkpoints>"
$ export ANNOTAION_FILE="gs://<path-to-coco-annotation-json>"
$ export TRAIN_DATA="gs://<path-to-train-data>"
$ export EVAL_DATA="gs://<path-to-eval-data>"
# Overrides to access data. These can also be changed in the config file.
$ export OVERRIDES="task.validation_data.input_path=${EVAL_DATA},\
task.train_data.input_path=${TRAIN_DATA},\
task.annotation_file=${ANNOTAION_FILE},\
runtime.distribution_strategy=tpu"
$ python3 -m official.projects.deepmac_maskrcnn.train \
--logtostderr \
--mode=train_and_eval \
--experiment=deep_mask_head_rcnn_resnetfpn_coco \
--model_dir=$MODEL_DIR \
--config_file=$CONFIG \
--params_override=$OVERRIDES\
--tpu=<tpu-name>
```
`CONFIG_FILE` can be any file in the `configs/experiments` directory.
When using SpineNet models, please specify
`--experiment=deep_mask_head_rcnn_spinenet_coco`
**Note:** The default eval batch size of 32 discards some samples during
validation. For accurate vaidation statistics, launch a dedicated eval job on
TPU `v3-8` and set batch size to 8.
## Configurations
In the following table, we report the Mask mAP of our models on the non-VOC
classes when only training with masks for the VOC calsses. Performance is
measured on the `coco-val2017` set.
Backbone | Mask head | Config name | Mask mAP
:------------| :----------- | :-----------------------------------------------| -------:
ResNet-50 | Default | `deep_mask_head_rcnn_voc_r50.yaml` | 25.9
ResNet-50 | Hourglass-52 | `deep_mask_head_rcnn_voc_r50_hg52.yaml` | 33.1
ResNet-101 | Hourglass-52 | `deep_mask_head_rcnn_voc_r101_hg52.yaml` | 34.4
SpienNet-143 | Hourglass-52 | `deep_mask_head_rcnn_voc_spinenet143_hg52.yaml` | 38.7
## Checkpoints
This model takes Image + boxes as input and produces per-box instance
masks as output.
* [Mask-RCNN SpineNet backbone](https://storage.googleapis.com/tf_model_garden/vision/deepmac_maskrcnn/deepmarc_spinenet.zip)
## See also
* [DeepMAC model](https://github.com/tensorflow/models/blob/master/research/object_detection/g3doc/deepmac.md)
in the Object Detection API code base.
* Project website - [git.io/deepmac](https://google.github.io/deepmac/)
## Citation
```
@misc{birodkar2021surprising,
title={The surprising impact of mask-head architecture on novel class segmentation},
author={Vighnesh Birodkar and Zhichao Lu and Siyang Li and Vivek Rathod and Jonathan Huang},
year={2021},
eprint={2104.00613},
archivePrefix={arXiv},
primaryClass={cs.CV}
}
```
# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Imports to configure Mask R-CNN with deep mask heads."""
# pylint: disable=unused-import
from official.projects.deepmac_maskrcnn.tasks import deep_mask_head_rcnn
# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Configuration for Mask R-CNN with deep mask heads."""
import dataclasses
import os
from typing import Optional
from official.core import config_definitions as cfg
from official.core import exp_factory
from official.modeling import optimization
from official.vision.configs import backbones
from official.vision.configs import common
from official.vision.configs import decoders
from official.vision.configs import maskrcnn as maskrcnn_config
from official.vision.configs import retinanet as retinanet_config
@dataclasses.dataclass
class DeepMaskHead(maskrcnn_config.MaskHead):
convnet_variant: str = 'default'
@dataclasses.dataclass
class DeepMaskHeadRCNN(maskrcnn_config.MaskRCNN):
mask_head: Optional[DeepMaskHead] = DeepMaskHead()
use_gt_boxes_for_masks: bool = False
@dataclasses.dataclass
class DeepMaskHeadRCNNTask(maskrcnn_config.MaskRCNNTask):
"""Configuration for the deep mask head R-CNN task."""
model: DeepMaskHeadRCNN = DeepMaskHeadRCNN()
@exp_factory.register_config_factory('deep_mask_head_rcnn_resnetfpn_coco')
def deep_mask_head_rcnn_resnetfpn_coco() -> cfg.ExperimentConfig:
"""COCO object detection with Mask R-CNN with deep mask heads."""
global_batch_size = 64
steps_per_epoch = int(retinanet_config.COCO_TRAIN_EXAMPLES /
global_batch_size)
coco_val_samples = 5000
config = cfg.ExperimentConfig(
runtime=cfg.RuntimeConfig(mixed_precision_dtype='bfloat16'),
task=DeepMaskHeadRCNNTask(
init_checkpoint='gs://cloud-tpu-checkpoints/vision-2.0/resnet50_imagenet/ckpt-28080',
init_checkpoint_modules='backbone',
annotation_file=os.path.join(maskrcnn_config.COCO_INPUT_PATH_BASE,
'instances_val2017.json'),
model=DeepMaskHeadRCNN(
num_classes=91, input_size=[1024, 1024, 3], include_mask=True), # pytype: disable=wrong-keyword-args
losses=maskrcnn_config.Losses(l2_weight_decay=0.00004),
train_data=maskrcnn_config.DataConfig(
input_path=os.path.join(maskrcnn_config.COCO_INPUT_PATH_BASE,
'train*'),
is_training=True,
global_batch_size=global_batch_size,
parser=maskrcnn_config.Parser(
aug_rand_hflip=True, aug_scale_min=0.8, aug_scale_max=1.25)),
validation_data=maskrcnn_config.DataConfig(
input_path=os.path.join(maskrcnn_config.COCO_INPUT_PATH_BASE,
'val*'),
is_training=False,
global_batch_size=8)), # pytype: disable=wrong-keyword-args
trainer=cfg.TrainerConfig(
train_steps=22500,
validation_steps=coco_val_samples // 8,
validation_interval=steps_per_epoch,
steps_per_loop=steps_per_epoch,
summary_interval=steps_per_epoch,
checkpoint_interval=steps_per_epoch,
optimizer_config=optimization.OptimizationConfig({
'optimizer': {
'type': 'sgd',
'sgd': {
'momentum': 0.9
}
},
'learning_rate': {
'type': 'stepwise',
'stepwise': {
'boundaries': [15000, 20000],
'values': [0.12, 0.012, 0.0012],
}
},
'warmup': {
'type': 'linear',
'linear': {
'warmup_steps': 500,
'warmup_learning_rate': 0.0067
}
}
})),
restrictions=[
'task.train_data.is_training != None',
'task.validation_data.is_training != None'
])
return config
@exp_factory.register_config_factory('deep_mask_head_rcnn_spinenet_coco')
def deep_mask_head_rcnn_spinenet_coco() -> cfg.ExperimentConfig:
"""COCO object detection with Mask R-CNN with SpineNet backbone."""
steps_per_epoch = 463
coco_val_samples = 5000
train_batch_size = 256
eval_batch_size = 8
config = cfg.ExperimentConfig(
runtime=cfg.RuntimeConfig(mixed_precision_dtype='bfloat16'),
task=DeepMaskHeadRCNNTask(
annotation_file=os.path.join(maskrcnn_config.COCO_INPUT_PATH_BASE,
'instances_val2017.json'), # pytype: disable=wrong-keyword-args
model=DeepMaskHeadRCNN(
backbone=backbones.Backbone(
type='spinenet',
spinenet=backbones.SpineNet(
model_id='49',
min_level=3,
max_level=7,
)),
decoder=decoders.Decoder(
type='identity', identity=decoders.Identity()),
anchor=maskrcnn_config.Anchor(anchor_size=3),
norm_activation=common.NormActivation(use_sync_bn=True),
num_classes=91,
input_size=[640, 640, 3],
min_level=3,
max_level=7,
include_mask=True), # pytype: disable=wrong-keyword-args
losses=maskrcnn_config.Losses(l2_weight_decay=0.00004),
train_data=maskrcnn_config.DataConfig(
input_path=os.path.join(maskrcnn_config.COCO_INPUT_PATH_BASE,
'train*'),
is_training=True,
global_batch_size=train_batch_size,
parser=maskrcnn_config.Parser(
aug_rand_hflip=True, aug_scale_min=0.5, aug_scale_max=2.0)),
validation_data=maskrcnn_config.DataConfig(
input_path=os.path.join(maskrcnn_config.COCO_INPUT_PATH_BASE,
'val*'),
is_training=False,
global_batch_size=eval_batch_size,
drop_remainder=False)), # pytype: disable=wrong-keyword-args
trainer=cfg.TrainerConfig(
train_steps=steps_per_epoch * 350,
validation_steps=coco_val_samples // eval_batch_size,
validation_interval=steps_per_epoch,
steps_per_loop=steps_per_epoch,
summary_interval=steps_per_epoch,
checkpoint_interval=steps_per_epoch,
optimizer_config=optimization.OptimizationConfig({
'optimizer': {
'type': 'sgd',
'sgd': {
'momentum': 0.9
}
},
'learning_rate': {
'type': 'stepwise',
'stepwise': {
'boundaries': [
steps_per_epoch * 320, steps_per_epoch * 340
],
'values': [0.32, 0.032, 0.0032],
}
},
'warmup': {
'type': 'linear',
'linear': {
'warmup_steps': 2000,
'warmup_learning_rate': 0.0067
}
}
})),
restrictions=[
'task.train_data.is_training != None',
'task.validation_data.is_training != None',
'task.model.min_level == task.model.backbone.spinenet.min_level',
'task.model.max_level == task.model.backbone.spinenet.max_level',
])
return config
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
......@@ -16,7 +16,7 @@
import tensorflow as tf
from official.vision.beta.projects.deepmac_maskrcnn.configs import deep_mask_head_rcnn
from official.projects.deepmac_maskrcnn.configs import deep_mask_head_rcnn
class DeepMaskHeadRcnnConfigTest(tf.test.TestCase):
......
# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
......
# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Instance prediction heads."""
# Import libraries
from absl import logging
import tensorflow as tf
from official.modeling import tf_utils
from official.projects.deepmac_maskrcnn.modeling.heads import hourglass_network
class DeepMaskHead(tf.keras.layers.Layer):
"""Creates a mask head."""
def __init__(self,
num_classes,
upsample_factor=2,
num_convs=4,
num_filters=256,
use_separable_conv=False,
activation='relu',
use_sync_bn=False,
norm_momentum=0.99,
norm_epsilon=0.001,
kernel_regularizer=None,
bias_regularizer=None,
class_agnostic=False,
convnet_variant='default',
**kwargs):
"""Initializes a mask head.
Args:
num_classes: An `int` of the number of classes.
upsample_factor: An `int` that indicates the upsample factor to generate
the final predicted masks. It should be >= 1.
num_convs: An `int` number that represents the number of the intermediate
convolution layers before the mask prediction layers.
num_filters: An `int` number that represents the number of filters of the
intermediate convolution layers.
use_separable_conv: A `bool` that indicates whether the separable
convolution layers is used.
activation: A `str` that indicates which activation is used, e.g. 'relu',
'swish', etc.
use_sync_bn: A `bool` that indicates whether to use synchronized batch
normalization across different replicas.
norm_momentum: A `float` of normalization momentum for the moving average.
norm_epsilon: A `float` added to variance to avoid dividing by zero.
kernel_regularizer: A `tf.keras.regularizers.Regularizer` object for
Conv2D. Default is None.
bias_regularizer: A `tf.keras.regularizers.Regularizer` object for Conv2D.
class_agnostic: A `bool`. If set, we use a single channel mask head that
is shared between all classes.
convnet_variant: A `str` denoting the architecture of network used in the
head. Supported options are 'default', 'hourglass20', 'hourglass52'
and 'hourglass100'.
**kwargs: Additional keyword arguments to be passed.
"""
super(DeepMaskHead, self).__init__(**kwargs)
self._config_dict = {
'num_classes': num_classes,
'upsample_factor': upsample_factor,
'num_convs': num_convs,
'num_filters': num_filters,
'use_separable_conv': use_separable_conv,
'activation': activation,
'use_sync_bn': use_sync_bn,
'norm_momentum': norm_momentum,
'norm_epsilon': norm_epsilon,
'kernel_regularizer': kernel_regularizer,
'bias_regularizer': bias_regularizer,
'class_agnostic': class_agnostic,
'convnet_variant': convnet_variant,
}
if tf.keras.backend.image_data_format() == 'channels_last':
self._bn_axis = -1
else:
self._bn_axis = 1
self._activation = tf_utils.get_activation(activation)
def _get_conv_op_and_kwargs(self):
conv_op = (tf.keras.layers.SeparableConv2D
if self._config_dict['use_separable_conv']
else tf.keras.layers.Conv2D)
conv_kwargs = {
'filters': self._config_dict['num_filters'],
'kernel_size': 3,
'padding': 'same',
}
if self._config_dict['use_separable_conv']:
conv_kwargs.update({
'depthwise_initializer': tf.keras.initializers.VarianceScaling(
scale=2, mode='fan_out', distribution='untruncated_normal'),
'pointwise_initializer': tf.keras.initializers.VarianceScaling(
scale=2, mode='fan_out', distribution='untruncated_normal'),
'bias_initializer': tf.zeros_initializer(),
'depthwise_regularizer': self._config_dict['kernel_regularizer'],
'pointwise_regularizer': self._config_dict['kernel_regularizer'],
'bias_regularizer': self._config_dict['bias_regularizer'],
})
else:
conv_kwargs.update({
'kernel_initializer': tf.keras.initializers.VarianceScaling(
scale=2, mode='fan_out', distribution='untruncated_normal'),
'bias_initializer': tf.zeros_initializer(),
'kernel_regularizer': self._config_dict['kernel_regularizer'],
'bias_regularizer': self._config_dict['bias_regularizer'],
})
return conv_op, conv_kwargs
def _get_bn_op_and_kwargs(self):
bn_op = (tf.keras.layers.experimental.SyncBatchNormalization
if self._config_dict['use_sync_bn']
else tf.keras.layers.BatchNormalization)
bn_kwargs = {
'axis': self._bn_axis,
'momentum': self._config_dict['norm_momentum'],
'epsilon': self._config_dict['norm_epsilon'],
}
return bn_op, bn_kwargs
def build(self, input_shape):
"""Creates the variables of the head."""
conv_op, conv_kwargs = self._get_conv_op_and_kwargs()
self._build_convnet_variant()
self._deconv = tf.keras.layers.Conv2DTranspose(
filters=self._config_dict['num_filters'],
kernel_size=self._config_dict['upsample_factor'],
strides=self._config_dict['upsample_factor'],
padding='valid',
kernel_initializer=tf.keras.initializers.VarianceScaling(
scale=2, mode='fan_out', distribution='untruncated_normal'),
bias_initializer=tf.zeros_initializer(),
kernel_regularizer=self._config_dict['kernel_regularizer'],
bias_regularizer=self._config_dict['bias_regularizer'],
name='mask-upsampling')
bn_op, bn_kwargs = self._get_bn_op_and_kwargs()
self._deconv_bn = bn_op(name='mask-deconv-bn', **bn_kwargs)
if self._config_dict['class_agnostic']:
num_filters = 1
else:
num_filters = self._config_dict['num_classes']
conv_kwargs = {
'filters': num_filters,
'kernel_size': 1,
'padding': 'valid',
}
if self._config_dict['use_separable_conv']:
conv_kwargs.update({
'depthwise_initializer': tf.keras.initializers.VarianceScaling(
scale=2, mode='fan_out', distribution='untruncated_normal'),
'pointwise_initializer': tf.keras.initializers.VarianceScaling(
scale=2, mode='fan_out', distribution='untruncated_normal'),
'bias_initializer': tf.zeros_initializer(),
'depthwise_regularizer': self._config_dict['kernel_regularizer'],
'pointwise_regularizer': self._config_dict['kernel_regularizer'],
'bias_regularizer': self._config_dict['bias_regularizer'],
})
else:
conv_kwargs.update({
'kernel_initializer': tf.keras.initializers.VarianceScaling(
scale=2, mode='fan_out', distribution='untruncated_normal'),
'bias_initializer': tf.zeros_initializer(),
'kernel_regularizer': self._config_dict['kernel_regularizer'],
'bias_regularizer': self._config_dict['bias_regularizer'],
})
self._mask_regressor = conv_op(name='mask-logits', **conv_kwargs)
super(DeepMaskHead, self).build(input_shape)
def call(self, inputs, training=None):
"""Forward pass of mask branch for the Mask-RCNN model.
Args:
inputs: A `list` of two tensors where
inputs[0]: A `tf.Tensor` of shape [batch_size, num_instances,
roi_height, roi_width, roi_channels], representing the ROI features.
inputs[1]: A `tf.Tensor` of shape [batch_size, num_instances],
representing the classes of the ROIs.
training: A `bool` indicating whether it is in `training` mode.
Returns:
mask_outputs: A `tf.Tensor` of shape
[batch_size, num_instances, roi_height * upsample_factor,
roi_width * upsample_factor], representing the mask predictions.
"""
roi_features, roi_classes = inputs
features_shape = tf.shape(roi_features)
batch_size, num_rois, height, width, filters = (
features_shape[0], features_shape[1], features_shape[2],
features_shape[3], features_shape[4])
if batch_size is None:
batch_size = tf.shape(roi_features)[0]
x = tf.reshape(roi_features, [-1, height, width, filters])
x = self._call_convnet_variant(x)
x = self._deconv(x)
x = self._deconv_bn(x)
x = self._activation(x)
logits = self._mask_regressor(x)
mask_height = height * self._config_dict['upsample_factor']
mask_width = width * self._config_dict['upsample_factor']
if self._config_dict['class_agnostic']:
logits = tf.reshape(logits, [-1, num_rois, mask_height, mask_width, 1])
else:
logits = tf.reshape(
logits,
[-1, num_rois, mask_height, mask_width,
self._config_dict['num_classes']])
batch_indices = tf.tile(
tf.expand_dims(tf.range(batch_size), axis=1), [1, num_rois])
mask_indices = tf.tile(
tf.expand_dims(tf.range(num_rois), axis=0), [batch_size, 1])
if self._config_dict['class_agnostic']:
class_gather_indices = tf.zeros_like(roi_classes, dtype=tf.int32)
else:
class_gather_indices = tf.cast(roi_classes, dtype=tf.int32)
gather_indices = tf.stack(
[batch_indices, mask_indices, class_gather_indices],
axis=2)
mask_outputs = tf.gather_nd(
tf.transpose(logits, [0, 1, 4, 2, 3]), gather_indices)
return mask_outputs
def _build_convnet_variant(self):
variant = self._config_dict['convnet_variant']
if variant == 'default':
bn_op, bn_kwargs = self._get_bn_op_and_kwargs()
self._convs = []
self._conv_norms = []
for i in range(self._config_dict['num_convs']):
conv_name = 'mask-conv_{}'.format(i)
conv_op, conv_kwargs = self._get_conv_op_and_kwargs()
self._convs.append(conv_op(name=conv_name, **conv_kwargs))
bn_name = 'mask-conv-bn_{}'.format(i)
self._conv_norms.append(bn_op(name=bn_name, **bn_kwargs))
elif variant == 'hourglass20':
logging.info('Using hourglass 20 network.')
self._hourglass = hourglass_network.hourglass_20(
self._config_dict['num_filters'], initial_downsample=False)
elif variant == 'hourglass52':
logging.info('Using hourglass 52 network.')
self._hourglass = hourglass_network.hourglass_52(
self._config_dict['num_filters'], initial_downsample=False)
elif variant == 'hourglass100':
logging.info('Using hourglass 100 network.')
self._hourglass = hourglass_network.hourglass_100(
self._config_dict['num_filters'], initial_downsample=False)
else:
raise ValueError('Unknown ConvNet variant - {}'.format(variant))
def _call_convnet_variant(self, x):
variant = self._config_dict['convnet_variant']
if variant == 'default':
for conv, bn in zip(self._convs, self._conv_norms):
x = conv(x)
x = bn(x)
x = self._activation(x)
return x
elif variant == 'hourglass20':
return self._hourglass(x)[-1]
elif variant == 'hourglass52':
return self._hourglass(x)[-1]
elif variant == 'hourglass100':
return self._hourglass(x)[-1]
else:
raise ValueError('Unknown ConvNet variant - {}'.format(variant))
def get_config(self):
return self._config_dict
@classmethod
def from_config(cls, config):
return cls(**config)
# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for instance_heads.py."""
# Import libraries
from absl.testing import parameterized
import numpy as np
import tensorflow as tf
from official.projects.deepmac_maskrcnn.modeling.heads import instance_heads as deep_instance_heads
class MaskHeadTest(parameterized.TestCase, tf.test.TestCase):
@parameterized.parameters(
(1, 1, False),
(1, 2, False),
(2, 1, False),
(2, 2, False),
)
def test_forward(self, upsample_factor, num_convs, use_sync_bn):
mask_head = deep_instance_heads.DeepMaskHead(
num_classes=3,
upsample_factor=upsample_factor,
num_convs=num_convs,
num_filters=16,
use_separable_conv=False,
activation='relu',
use_sync_bn=use_sync_bn,
norm_momentum=0.99,
norm_epsilon=0.001,
kernel_regularizer=None,
bias_regularizer=None,
)
roi_features = np.random.rand(2, 10, 14, 14, 16)
roi_classes = np.zeros((2, 10))
masks = mask_head([roi_features, roi_classes])
self.assertAllEqual(
masks.numpy().shape,
[2, 10, 14 * upsample_factor, 14 * upsample_factor])
def test_serialize_deserialize(self):
mask_head = deep_instance_heads.DeepMaskHead(
num_classes=3,
upsample_factor=2,
num_convs=1,
num_filters=256,
use_separable_conv=False,
activation='relu',
use_sync_bn=False,
norm_momentum=0.99,
norm_epsilon=0.001,
kernel_regularizer=None,
bias_regularizer=None,
)
config = mask_head.get_config()
new_mask_head = deep_instance_heads.DeepMaskHead.from_config(config)
self.assertAllEqual(
mask_head.get_config(), new_mask_head.get_config())
def test_forward_class_agnostic(self):
mask_head = deep_instance_heads.DeepMaskHead(
num_classes=3,
class_agnostic=True
)
roi_features = np.random.rand(2, 10, 14, 14, 16)
roi_classes = np.zeros((2, 10))
masks = mask_head([roi_features, roi_classes])
self.assertAllEqual(masks.numpy().shape, [2, 10, 28, 28])
def test_instance_head_hourglass(self):
mask_head = deep_instance_heads.DeepMaskHead(
num_classes=3,
class_agnostic=True,
convnet_variant='hourglass20',
num_filters=32,
upsample_factor=2
)
roi_features = np.random.rand(2, 10, 16, 16, 16)
roi_classes = np.zeros((2, 10))
masks = mask_head([roi_features, roi_classes])
self.assertAllEqual(masks.numpy().shape, [2, 10, 32, 32])
if __name__ == '__main__':
tf.test.main()
# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Mask R-CNN model."""
from typing import List, Mapping, Optional, Union
# Import libraries
from absl import logging
import tensorflow as tf
from official.vision.modeling import maskrcnn_model
def resize_as(source, size):
source = tf.transpose(source, (0, 2, 3, 1))
source = tf.image.resize(source, (size, size))
return tf.transpose(source, (0, 3, 1, 2))
class DeepMaskRCNNModel(maskrcnn_model.MaskRCNNModel):
"""The Mask R-CNN model."""
def __init__(self,
backbone: tf.keras.Model,
decoder: tf.keras.Model,
rpn_head: tf.keras.layers.Layer,
detection_head: Union[tf.keras.layers.Layer,
List[tf.keras.layers.Layer]],
roi_generator: tf.keras.layers.Layer,
roi_sampler: Union[tf.keras.layers.Layer,
List[tf.keras.layers.Layer]],
roi_aligner: tf.keras.layers.Layer,
detection_generator: tf.keras.layers.Layer,
mask_head: Optional[tf.keras.layers.Layer] = None,
mask_sampler: Optional[tf.keras.layers.Layer] = None,
mask_roi_aligner: Optional[tf.keras.layers.Layer] = None,
class_agnostic_bbox_pred: bool = False,
cascade_class_ensemble: bool = False,
min_level: Optional[int] = None,
max_level: Optional[int] = None,
num_scales: Optional[int] = None,
aspect_ratios: Optional[List[float]] = None,
anchor_size: Optional[float] = None,
use_gt_boxes_for_masks=False,
**kwargs):
"""Initializes the Mask R-CNN model.
Args:
backbone: `tf.keras.Model`, the backbone network.
decoder: `tf.keras.Model`, the decoder network.
rpn_head: the RPN head.
detection_head: the detection head or a list of heads.
roi_generator: the ROI generator.
roi_sampler: a single ROI sampler or a list of ROI samplers for cascade
detection heads.
roi_aligner: the ROI aligner.
detection_generator: the detection generator.
mask_head: the mask head.
mask_sampler: the mask sampler.
mask_roi_aligner: the ROI alginer for mask prediction.
class_agnostic_bbox_pred: if True, perform class agnostic bounding box
prediction. Needs to be `True` for Cascade RCNN models.
cascade_class_ensemble: if True, ensemble classification scores over all
detection heads.
min_level: Minimum level in output feature maps.
max_level: Maximum level in output feature maps.
num_scales: A number representing intermediate scales added on each level.
For instances, num_scales=2 adds one additional intermediate anchor
scales [2^0, 2^0.5] on each level.
aspect_ratios: A list representing the aspect raito anchors added on each
level. The number indicates the ratio of width to height. For instances,
aspect_ratios=[1.0, 2.0, 0.5] adds three anchors on each scale level.
anchor_size: A number representing the scale of size of the base anchor to
the feature stride 2^level.
use_gt_boxes_for_masks: bool, if set, crop using groundtruth boxes instead
of proposals for training mask head
**kwargs: keyword arguments to be passed.
"""
super(DeepMaskRCNNModel, self).__init__(
backbone=backbone,
decoder=decoder,
rpn_head=rpn_head,
detection_head=detection_head,
roi_generator=roi_generator,
roi_sampler=roi_sampler,
roi_aligner=roi_aligner,
detection_generator=detection_generator,
mask_head=mask_head,
mask_sampler=mask_sampler,
mask_roi_aligner=mask_roi_aligner,
class_agnostic_bbox_pred=class_agnostic_bbox_pred,
cascade_class_ensemble=cascade_class_ensemble,
min_level=min_level,
max_level=max_level,
num_scales=num_scales,
aspect_ratios=aspect_ratios,
anchor_size=anchor_size,
**kwargs)
self._config_dict['use_gt_boxes_for_masks'] = use_gt_boxes_for_masks
def call(self,
images: tf.Tensor,
image_shape: tf.Tensor,
anchor_boxes: Optional[Mapping[str, tf.Tensor]] = None,
gt_boxes: Optional[tf.Tensor] = None,
gt_classes: Optional[tf.Tensor] = None,
gt_masks: Optional[tf.Tensor] = None,
training: Optional[bool] = None) -> Mapping[str, tf.Tensor]:
model_outputs, intermediate_outputs = self._call_box_outputs(
images=images, image_shape=image_shape, anchor_boxes=anchor_boxes,
gt_boxes=gt_boxes, gt_classes=gt_classes, training=training)
if not self._include_mask:
return model_outputs
model_mask_outputs = self._call_mask_outputs(
model_box_outputs=model_outputs,
features=model_outputs['decoder_features'],
current_rois=intermediate_outputs['current_rois'],
matched_gt_indices=intermediate_outputs['matched_gt_indices'],
matched_gt_boxes=intermediate_outputs['matched_gt_boxes'],
matched_gt_classes=intermediate_outputs['matched_gt_classes'],
gt_masks=gt_masks,
gt_classes=gt_classes,
gt_boxes=gt_boxes,
training=training)
model_outputs.update(model_mask_outputs)
return model_outputs
def call_images_and_boxes(self, images, boxes):
"""Predict masks given an image and bounding boxes."""
_, decoder_features = self._get_backbone_and_decoder_features(images)
boxes_shape = tf.shape(boxes)
batch_size, num_boxes = boxes_shape[0], boxes_shape[1]
classes = tf.zeros((batch_size, num_boxes), dtype=tf.int32)
_, mask_probs = self._features_to_mask_outputs(
decoder_features, boxes, classes)
return {
'detection_masks': mask_probs
}
def _call_mask_outputs(
self,
model_box_outputs: Mapping[str, tf.Tensor],
features: tf.Tensor,
current_rois: tf.Tensor,
matched_gt_indices: tf.Tensor,
matched_gt_boxes: tf.Tensor,
matched_gt_classes: tf.Tensor,
gt_masks: tf.Tensor,
gt_classes: tf.Tensor,
gt_boxes: tf.Tensor,
training: Optional[bool] = None) -> Mapping[str, tf.Tensor]:
model_outputs = dict(model_box_outputs)
if training:
if self._config_dict['use_gt_boxes_for_masks']:
mask_size = (
self.mask_roi_aligner._config_dict['crop_size'] * # pylint:disable=protected-access
self.mask_head._config_dict['upsample_factor'] # pylint:disable=protected-access
)
gt_masks = resize_as(source=gt_masks, size=mask_size)
logging.info('Using GT class and mask targets.')
model_outputs.update({
'mask_class_targets': gt_classes,
'mask_targets': gt_masks,
})
else:
rois, roi_classes, roi_masks = self.mask_sampler(
current_rois, matched_gt_boxes, matched_gt_classes,
matched_gt_indices, gt_masks)
roi_masks = tf.stop_gradient(roi_masks)
model_outputs.update({
'mask_class_targets': roi_classes,
'mask_targets': roi_masks,
})
else:
rois = model_outputs['detection_boxes']
roi_classes = model_outputs['detection_classes']
# Mask RoI align.
if training and self._config_dict['use_gt_boxes_for_masks']:
logging.info('Using GT mask roi features.')
roi_aligner_boxes = gt_boxes
mask_head_classes = gt_classes
else:
roi_aligner_boxes = rois
mask_head_classes = roi_classes
mask_logits, mask_probs = self._features_to_mask_outputs(
features, roi_aligner_boxes, mask_head_classes)
if training:
model_outputs.update({
'mask_outputs': mask_logits,
})
else:
model_outputs.update({
'detection_masks': mask_probs,
})
return model_outputs
# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for maskrcnn_model.py."""
# Import libraries
from absl.testing import parameterized
import numpy as np
import tensorflow as tf
from official.projects.deepmac_maskrcnn.modeling import maskrcnn_model
from official.projects.deepmac_maskrcnn.modeling.heads import instance_heads as deep_instance_heads
from official.vision.modeling.backbones import resnet
from official.vision.modeling.decoders import fpn
from official.vision.modeling.heads import dense_prediction_heads
from official.vision.modeling.heads import instance_heads
from official.vision.modeling.layers import detection_generator
from official.vision.modeling.layers import mask_sampler
from official.vision.modeling.layers import roi_aligner
from official.vision.modeling.layers import roi_generator
from official.vision.modeling.layers import roi_sampler
from official.vision.ops import anchor
def construct_model_and_anchors(image_size, use_gt_boxes_for_masks):
num_classes = 3
min_level = 3
max_level = 4
num_scales = 3
aspect_ratios = [1.0]
anchor_boxes = anchor.Anchor(
min_level=min_level,
max_level=max_level,
num_scales=num_scales,
aspect_ratios=aspect_ratios,
anchor_size=3,
image_size=image_size).multilevel_boxes
num_anchors_per_location = len(aspect_ratios) * num_scales
input_specs = tf.keras.layers.InputSpec(shape=[None, None, None, 3])
backbone = resnet.ResNet(model_id=50, input_specs=input_specs)
decoder = fpn.FPN(
min_level=min_level,
max_level=max_level,
input_specs=backbone.output_specs)
rpn_head = dense_prediction_heads.RPNHead(
min_level=min_level,
max_level=max_level,
num_anchors_per_location=num_anchors_per_location)
detection_head = instance_heads.DetectionHead(
num_classes=num_classes)
roi_generator_obj = roi_generator.MultilevelROIGenerator()
roi_sampler_obj = roi_sampler.ROISampler()
roi_aligner_obj = roi_aligner.MultilevelROIAligner()
detection_generator_obj = detection_generator.DetectionGenerator()
mask_head = deep_instance_heads.DeepMaskHead(
num_classes=num_classes, upsample_factor=2)
mask_sampler_obj = mask_sampler.MaskSampler(
mask_target_size=28, num_sampled_masks=1)
mask_roi_aligner_obj = roi_aligner.MultilevelROIAligner(crop_size=14)
model = maskrcnn_model.DeepMaskRCNNModel(
backbone,
decoder,
rpn_head,
detection_head,
roi_generator_obj,
roi_sampler_obj,
roi_aligner_obj,
detection_generator_obj,
mask_head,
mask_sampler_obj,
mask_roi_aligner_obj,
use_gt_boxes_for_masks=use_gt_boxes_for_masks)
return model, anchor_boxes
class MaskRCNNModelTest(parameterized.TestCase, tf.test.TestCase):
@parameterized.parameters(
(False, False,),
(False, True,),
(True, False,),
(True, True,),
)
def test_forward(self, use_gt_boxes_for_masks, training):
image_size = (256, 256)
images = np.random.rand(2, image_size[0], image_size[1], 3)
image_shape = np.array([[224, 100], [100, 224]])
model, anchor_boxes = construct_model_and_anchors(
image_size, use_gt_boxes_for_masks)
gt_boxes = tf.zeros((2, 16, 4), dtype=tf.float32)
gt_masks = tf.zeros((2, 16, 32, 32))
gt_classes = tf.zeros((2, 16), dtype=tf.int32)
results = model(images.astype(np.uint8),
image_shape,
anchor_boxes,
gt_boxes,
gt_classes,
gt_masks,
training=training)
self.assertIn('rpn_boxes', results)
self.assertIn('rpn_scores', results)
if training:
self.assertIn('class_targets', results)
self.assertIn('box_targets', results)
self.assertIn('class_outputs', results)
self.assertIn('box_outputs', results)
self.assertIn('mask_outputs', results)
self.assertEqual(results['mask_targets'].shape,
results['mask_outputs'].shape)
else:
self.assertIn('detection_boxes', results)
self.assertIn('detection_scores', results)
self.assertIn('detection_classes', results)
self.assertIn('num_detections', results)
self.assertIn('detection_masks', results)
@parameterized.parameters(
[(1, 5), (1, 10), (1, 15), (2, 5), (2, 10), (2, 15)]
)
def test_image_and_boxes(self, batch_size, num_boxes):
image_size = (640, 640)
images = np.random.rand(1, image_size[0], image_size[1], 3).astype(
np.float32)
model, _ = construct_model_and_anchors(
image_size, use_gt_boxes_for_masks=True)
boxes = np.zeros((1, num_boxes, 4), dtype=np.float32)
boxes[:, :, [2, 3]] = 1.0
boxes = tf.constant(boxes)
results = model.call_images_and_boxes(images, boxes)
self.assertIn('detection_masks', results)
if __name__ == '__main__':
tf.test.main()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment