Commit 43621877 authored by Mark McDonald, committed by A. Unique TensorFlower

Remove `git clone` instruction.

There is no actual command here, and everything is downloaded, so the instruction is redundant.

PiperOrigin-RevId: 447913568
parent 277dac40
{
"cells": [
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"cellView": "form",
"id": "xBH8CcrkV3IU"
},
"outputs": [],
"source": [
"#@title Licensed under the Apache License, Version 2.0 (the \"License\");\n",
"# you may not use this file except in compliance with the License.\n",
"# You may obtain a copy of the License at\n",
"#\n",
"# https://www.apache.org/licenses/LICENSE-2.0\n",
"#\n",
"# Unless required by applicable law or agreed to in writing, software\n",
"# distributed under the License is distributed on an \"AS IS\" BASIS,\n",
"# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
"# See the License for the specific language governing permissions and\n",
"# limitations under the License."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "9CzbXNRovpbc"
},
"source": [
"# Crown-of-Thorns Starfish Detection Pipeline"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "Lpb0yoNjiWhw"
},
"source": [
"\u003ctable class=\"tfo-notebook-buttons\" align=\"left\"\u003e\n",
" \u003ctd\u003e\n",
" \u003ca target=\"_blank\" href=\"https://colab.research.google.com/github/tensorflow/models/blob/master/official/projects/cots_detector/COTS_detection_inference_and_tracking.ipynb\"\u003e\u003cimg src=\"https://www.tensorflow.org/images/colab_logo_32px.png\" /\u003eRun in Google Colab\u003c/a\u003e\n",
" \u003c/td\u003e\n",
" \u003ctd\u003e\n",
" \u003ca target=\"_blank\" href=\"https://github.com/tensorflow/models/blob/master/official/projects/cots_detector/COTS_detection_inference_and_tracking.ipynb\"\u003e\u003cimg src=\"https://www.tensorflow.org/images/GitHub-Mark-32px.png\" /\u003eView on GitHub\u003c/a\u003e\n",
" \u003c/td\u003e\n",
"\u003c/table\u003e"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "GUQ1x137ysLD"
},
"source": [
"Coral reefs are some of the most diverse and important ecosystems in the world , however they face a number of rising threats that have resulted in massive global declines. In Australia, outbreaks of the coral-eating crown-of-thorns starfish (COTS) have been shown to cause major coral loss, with just 15 starfish in a hectare being able to strip a reef of 90% of its coral tissue. While COTS naturally exist in the Indo-Pacific, overfishing and excess run-off nutrients have led to massive outbreaks that are devastating already vulnerable coral communities.\n",
"\n",
"Controlling COTS populations is critical to promoting coral growth and resilience, so Google teamed up with Australia’s national science agency, [CSIRO](https://www.csiro.au/en/), to tackle this problem. We trained ML object detection models to help scale underwater surveys, enabling the monitoring and mapping out these harmful invertebrates with the ultimate goal of helping control teams to address and prioritize outbreaks."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "jDiIX2xawkJw"
},
"source": [
"### This notebook\n",
"\n",
"This notebook tutorial shows how to detect COTS using a pre-trained COTS detector implemented in TensorFlow. On top of just running the model on each frame of the video, the tracking code in this notebook aligns detections from frame to frame creating a consistent track for each COTS. Each track is given an id and frame count. Here is an example image from a video of a reef showing labeled COTS starfish.\n",
"\n",
"\u003cimg src=\"https://storage.googleapis.com/download.tensorflow.org/data/cots_detection/COTS_detected_sample.png\"\u003e"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "YxCF1t-Skag8"
},
"source": [
"It is recommended to enable GPU to accelerate the inference. On CPU, this runs for about 40 minutes, but on GPU it takes only 10 minutes. (from colab menu: *Runtime \u003e Change runtime type \u003e Hardware accelerator \u003e select \"GPU\"*)."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "6AL2vLo3q39y"
},
"source": [
"#Setting up the environment\n",
"Clone the inference pipeline respository from Github."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "a4R2T97u442o"
},
"source": [
"Install all needed packages."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "5Gs7XvCGlwlj"
},
"outputs": [],
"source": [
"# remove the existing datascience package to avoid package conflicts in the colab environment\n",
"!pip3 uninstall -y datascience\n",
"!pip3 install -q opencv-python"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "w-UQ87240x5R"
},
"outputs": [],
"source": [
"# Imports\n",
"import base64\n",
"import copy\n",
"import dataclasses\n",
"import glob\n",
"import logging\n",
"import mimetypes\n",
"import os\n",
"import pathlib\n",
"import subprocess\n",
"import time\n",
"import textwrap\n",
"\n",
"from absl import logging as absl_logging\n",
"from IPython import display\n",
"import cv2\n",
"import numpy as np\n",
"import tensorflow as tf\n",
"from tqdm import tqdm"
]
},
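{
"cell_type": "markdown",
"metadata": {},
"source": [
"As an added sanity check (not part of the original pipeline), the next cell confirms whether the runtime actually exposes a GPU to TensorFlow; an empty list means inference will fall back to CPU."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Added check: list the GPUs visible to TensorFlow. An empty list means\n",
"# the notebook will run on CPU (roughly 4x slower for this pipeline).\n",
"gpus = tf.config.list_physical_devices('GPU')\n",
"print('GPUs visible to TensorFlow:', gpus if gpus else 'none')"
]
},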
{
"cell_type": "markdown",
"metadata": {
"id": "gsSclJg4sJbX"
},
"source": [
"Define all needed variables."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "iKMCvnZEXBBT"
},
"outputs": [],
"source": [
"model_name = \"cots_1080_v1\" #@param [\"cots_1080_v1\", \"cots_720_v1\"]\n",
"test_sequence_name = \"test3\" #@param [\"test1\", \"test2\", \"test3\"]"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "ORLJSdLq4-gd"
},
"outputs": [],
"source": [
"cots_model = f\"https://storage.googleapis.com/download.tensorflow.org/models/cots_detection/{model_name}.zip\"\n",
"\n",
"# Alternatively, this dataset can be downloaded through CSIRO's Data Access Portal at https://data.csiro.au/collection/csiro:54830v2\n",
"sample_data_link = f\"https://storage.googleapis.com/download.tensorflow.org/data/cots_detection/sample_images.zip\"\n",
"\n",
"preview_video_path = \"preview.mp4\"\n",
"detection_small_video_path = \"COTS_detection.mp4\"\n",
"detection_csv_path = \"detections.csv\""
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "fVq6vNBTxM62"
},
"source": [
"Also, download the trained COTS detection model that matches your preferences above."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "No5jRA1TxXj0"
},
"outputs": [],
"source": [
"model_path = tf.keras.utils.get_file(origin=cots_model)\n",
"# Unzip model\n",
"!mkdir {model_name}\n",
"!unzip -o -q {model_path} -d {model_name}"
]
},
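{
"cell_type": "markdown",
"metadata": {},
"source": [
"Optionally, verify the extraction (an added check; the expected `saved_model.pb` plus `variables/` layout is an assumption based on how the model is loaded with `tf.saved_model.load` below)."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Added check: the loading step below expects a SavedModel directory,\n",
"# i.e. a saved_model.pb file plus a variables/ subdirectory.\n",
"for name in sorted(os.listdir(model_name)):\n",
" print(name)"
]
},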
{
"cell_type": "markdown",
"metadata": {
"id": "FNwP3s-5xgaF"
},
"source": [
"You also need to retrieve the sample data. This sample data is made up of a series of chronological images."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "DF_c_ZMXdPRN"
},
"outputs": [],
"source": [
"sample_data_path = tf.keras.utils.get_file(origin=sample_data_link)\n",
"# Unzip data\n",
"!mkdir sample_images\n",
"!unzip -o -q {sample_data_path} -d sample_images"
]
},
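{
"cell_type": "markdown",
"metadata": {},
"source": [
"Optionally, check what was downloaded (an added step; the `sample_images/test*/` layout assumed here mirrors the glob pattern used later in this notebook)."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Added check: count the frames available in each test sequence.\n",
"for seq_dir in sorted(glob.glob('sample_images/test*')):\n",
" num_frames = len(glob.glob(f'{seq_dir}/*.jpg'))\n",
" print(f'{seq_dir}: {num_frames} frames')"
]
},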
{
"cell_type": "markdown",
"metadata": {
"id": "d0iALUwM0g2p"
},
"source": [
"# Load the model and perform inference and tracking on sample data\n",
"Load trained model from disk and create the inference function `model_fn()`. This might take a little while."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "HXQnNjwl8Beu"
},
"outputs": [],
"source": [
"absl_logging.set_verbosity(absl_logging.ERROR)\n",
"\n",
"tf.config.optimizer.set_experimental_options({'auto_mixed_precision': True})\n",
"tf.config.optimizer.set_jit(True)\n",
"\n",
"model_fn = tf.saved_model.load(model_name).signatures['serving_default']"
]
},
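{
"cell_type": "markdown",
"metadata": {},
"source": [
"As an optional, added step, you can inspect the serving signature directly; its output keys should match the four variables described in the inference section below."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Added check: print the input spec and output keys of the serving signature.\n",
"print('Inputs:', model_fn.structured_input_signature)\n",
"print('Outputs:', list(model_fn.structured_outputs.keys()))"
]
},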
{
"cell_type": "markdown",
"metadata": {
"id": "jbZ-7ICCENWG"
},
"source": [
"# Define **OpticalFlowTracker** class and its related classes\n",
"\n",
"These help track the movement of each COTS object throughout the image frames."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "tybwY3eaY803"
},
"outputs": [],
"source": [
"def box_area(x0, y0, x1, y1):\n",
" return (x1 - x0 + 1) * (y1 - y0 + 1)\n",
"\n",
"@dataclasses.dataclass\n",
"class Detection:\n",
" \"\"\"Detection dataclass.\"\"\"\n",
" class_id: int\n",
" score: float\n",
" x0: float\n",
" y0: float\n",
" x1: float\n",
" y1: float\n",
"\n",
" def __repr__(self):\n",
" return (f'Class {self.class_id}, score {self.score}, '\n",
" f'box ({self.x0}, {self.y0}, {self.x1}, {self.y1})')\n",
"\n",
" def area(self):\n",
" return box_area(self.x0, self.y0, self.x1, self.y1)\n",
"\n",
" def iou(self, other):\n",
" overlap_x0 = max(self.x0, other.x0)\n",
" overlap_y0 = max(self.y0, other.y0)\n",
" overlap_x1 = min(self.x1, other.x1)\n",
" overlap_y1 = min(self.y1, other.y1)\n",
" if overlap_x0 \u003c overlap_x1 and overlap_y0 \u003c overlap_y1:\n",
" overlap_area = box_area(overlap_x0, overlap_y0, overlap_x1,\n",
" overlap_y1)\n",
" return overlap_area / (self.area() + other.area() - overlap_area)\n",
" else:\n",
" return 0\n"
]
},
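{
"cell_type": "markdown",
"metadata": {},
"source": [
"As a quick added illustration (with made-up coordinates), the next cell exercises `Detection.iou` on two overlapping boxes. The tracker defined below links a detection to an existing track when this value exceeds its IoU threshold (0.5 by default)."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Added example: two hypothetical boxes with 50% horizontal overlap.\n",
"det_a = Detection(class_id=0, score=0.9, x0=0, y0=0, x1=100, y1=100)\n",
"det_b = Detection(class_id=0, score=0.8, x0=50, y0=0, x1=150, y1=100)\n",
"print(f'IoU of det_a and det_b: {det_a.iou(det_b):.3f}')  # roughly 1/3\n",
"print(f'IoU of det_a with itself: {det_a.iou(det_a):.3f}')  # exactly 1.0"
]
},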
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "gCQFfAkaY_WN"
},
"outputs": [],
"source": [
"class Tracklet:\n",
" def __init__(self, timestamp, detection):\n",
" self.timestamp = timestamp\n",
" # Store a copy here to make sure the coordinates will not be updated\n",
" # when the optical flow propagation runs using another reference to this\n",
" # detection.\n",
" self.detection = copy.deepcopy(detection)\n",
"\n",
" def __repr__(self):\n",
" return f'Time {self.timestamp}, ' + self.detection.__repr__()\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "7qVW1a_YZBgL"
},
"outputs": [],
"source": [
"class Track:\n",
" \"\"\"Tracker entries.\"\"\"\n",
" def __init__(self, id, detection):\n",
" self.id = id\n",
" self.linked_dets = []\n",
" self.det = detection\n",
"\n",
" def __repr__(self):\n",
" result = f'Track {self.id}'\n",
" for linked_det in self.linked_dets:\n",
" result += '\\n' + linked_det.__repr__()\n",
" return result\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "3j2Ka1uGEoz4"
},
"outputs": [],
"source": [
"class OpticalFlowTracker:\n",
" \"\"\"Optical flow tracker.\"\"\"\n",
" def __init__(self, tid, ft=3.0, iou=0.5, tt=2.0, bb=32, size=64, its=20,\n",
" eps=0.03, levels=3):\n",
" self.track_id = tid\n",
" # How long to apply optical flow tracking without getting positive \n",
" # detections (sec).\n",
" self.track_flow_time = ft * 1000\n",
" # Required IoU overlap to link a detection to a track.\n",
" self.overlap_threshold = iou\n",
" # Used to detect if detector needs to be reset.\n",
" self.time_threshold = tt * 1000\n",
" self.border = bb\n",
" # Size of optical flow region.\n",
" self.of_size = (size, size)\n",
" self.of_criteria = (cv2.TermCriteria_COUNT + cv2.TermCriteria_EPS, its, \n",
" eps)\n",
" self.of_levels= levels\n",
"\n",
" self.tracks = []\n",
" self.prev_image = None\n",
" self.prev_time = -1\n",
"\n",
" def update(self, image_bgr, detections, timestamp):\n",
" start = time.time()\n",
" num_optical_flow_calls = 0\n",
"\n",
" image = cv2.cvtColor(image_bgr, cv2.COLOR_BGR2GRAY)\n",
"\n",
" image_w = image.shape[1]\n",
" image_h = image.shape[0]\n",
"\n",
" # Assume tracker is invalid if too much time has passed!\n",
" if (self.prev_time \u003e 0 and\n",
" timestamp - self.prev_time \u003e self.time_threshold):\n",
" logging.info(\n",
" 'Too much time since last update, resetting tracker.')\n",
" self.tracks = []\n",
"\n",
" # Remove tracks which are:\n",
" # - Touching the image edge.\n",
" # - Have existed for a long time without linking a real detection.\n",
" active_tracks = []\n",
" for track in self.tracks:\n",
" if (track.det.x0 \u003c self.border or track.det.y0 \u003c self.border or\n",
" track.det.x1 \u003e= (image_w - self.border) or\n",
" track.det.y1 \u003e= (image_h - self.border)):\n",
" logging.info(f'Removing track {track.id} because it\\'s near the border')\n",
" continue\n",
"\n",
" time_since_last_detection = timestamp - track.linked_dets[-1].timestamp\n",
" if (time_since_last_detection \u003e self.track_flow_time):\n",
" logging.info(f'Removing track {track.id} because it\\'s too old '\n",
f'({time_since_last_detection:.02f}ms)')\n",
" continue\n",
"\n",
" active_tracks.append(track)\n",
"\n",
" self.tracks = active_tracks\n",
"\n",
" # Run optical flow to update existing tracks.\n",
" if self.prev_time \u003e 0:\n",
" # print('Running optical flow propagation.')\n",
" of_params = {\n",
" 'winSize': self.of_size,\n",
" 'maxLevel': self.of_levels,\n",
" 'criteria': self.of_criteria\n",
" }\n",
" for track in self.tracks:\n",
" input_points = np.float32([[[(track.det.x0 + track.det.x1) / 2,\n",
" (track.det.y0 + track.det.y1) / 2]]])\n",
" output_points, status, error = cv2.calcOpticalFlowPyrLK(\n",
" self.prev_image, image, input_points, None, **of_params)\n",
" num_optical_flow_calls += 1\n",
" w = track.det.x1 - track.det.x0\n",
" h = track.det.y1 - track.det.y0\n",
" # print(f'Detection before flow update: {track.det}')\n",
" track.det.x0 = output_points[0][0][0] - w * 0.5\n",
" track.det.y0 = output_points[0][0][1] - h * 0.5\n",
" track.det.x1 = output_points[0][0][0] + w * 0.5\n",
" track.det.y1 = output_points[0][0][1] + h * 0.5\n",
" # print(f'Detection after flow update: {track.det}')\n",
"\n",
"\n",
" # Insert new detections.\n",
" detected_obj_track_ids = set()\n",
"\n",
" for detection in detections:\n",
" if (detection.x0 \u003c self.border or detection.y0 \u003c self.border or\n",
" detection.x1 \u003e= image_w - self.border or\n",
" detection.y1 \u003e= image_h - self.border):\n",
" # print('Skipping detection because it\\'s close to the border.')\n",
" continue\n",
"\n",
" # See if detection can be linked to an existing track.\n",
" linked = False\n",
" overlap_index = 0\n",
" overlap_max = -1000\n",
" for track_index, track in enumerate(self.tracks):\n",
" # print(f'Testing track {track_index}')\n",
" if track.det.class_id != detection.class_id:\n",
" continue\n",
" overlap = detection.iou(track.det)\n",
" if overlap \u003e overlap_max:\n",
" overlap_index = track_index\n",
" overlap_max = overlap\n",
"\n",
" # Link to existing track with maximal IoU.\n",
" if overlap_max \u003e self.overlap_threshold:\n",
" track = self.tracks[overlap_index]\n",
" track.det = detection\n",
" track.linked_dets.append(Tracklet(timestamp, detection))\n",
" detected_obj_track_ids.add(track.id)\n",
" linked = True\n",
"\n",
" if not linked:\n",
" logging.info(f'Creating new track with ID {self.track_id}')\n",
" new_track = Track(self.track_id, detection)\n",
" new_track.linked_dets.append(Tracklet(timestamp, detection))\n",
" detected_obj_track_ids.add(self.track_id)\n",
" self.tracks.append(new_track)\n",
" self.track_id += 1\n",
"\n",
" for track in self.tracks:\n",
" # If the detector does not find the obj but estimated in the tracker, \n",
" # add the estimated one to that tracker's linked_dets\n",
" if track.id not in detected_obj_track_ids:\n",
" track.linked_dets.append(Tracklet(timestamp, track.det))\n",
"\n",
" self.prev_image = image\n",
" self.prev_time = timestamp\n",
"\n",
" if num_optical_flow_calls \u003e 0:\n",
" tracking_ms = int(1000 * (time.time() - start))\n",
" logging.info(f'Tracking took {tracking_ms}ms, '\n",
" f'{num_optical_flow_calls} optical flow calls')\n",
"\n",
" return self.tracks"
]
},
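{
"cell_type": "markdown",
"metadata": {},
"source": [
"The short demo below is an addition that runs the tracker on two blank synthetic frames with made-up boxes. It shows the intended behaviour: a detection in the second frame that overlaps an existing track by more than the IoU threshold is linked to that track instead of starting a new one."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Added demo: exercise the tracker on two synthetic BGR frames.\n",
"demo_tracker = OpticalFlowTracker(tid=1)\n",
"frame = np.zeros((480, 640, 3), dtype=np.uint8)\n",
"\n",
"# Frame 1: one detection well away from the 32px border creates track 1.\n",
"det_t0 = Detection(class_id=0, score=0.9, x0=200, y0=200, x1=300, y1=300)\n",
"demo_tracker.update(frame, [det_t0], timestamp=100)\n",
"\n",
"# Frame 2: the same box shifted by 10px; IoU \u003e 0.5, so it links to track 1.\n",
"det_t1 = Detection(class_id=0, score=0.9, x0=210, y0=200, x1=310, y1=300)\n",
"tracks = demo_tracker.update(frame, [det_t1], timestamp=200)\n",
"for track in tracks:\n",
" print(f'Track {track.id} with {len(track.linked_dets)} linked detections')"
]
},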
{
"cell_type": "markdown",
"metadata": {
"id": "Kkg3SazB1edC"
},
"source": [
"Create a list of images to work on from the downloaded files."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "u0fqXQUzdZCu"
},
"outputs": [],
"source": [
"def embed_video_file(path: os.PathLike) -\u003e display.HTML:\n",
" \"\"\"Embeds a file in the notebook as an html tag with a data-url.\"\"\"\n",
" path = pathlib.Path(path)\n",
" mime, unused_encoding = mimetypes.guess_type(str(path))\n",
" data = path.read_bytes()\n",
"\n",
" b64 = base64.b64encode(data).decode()\n",
" return display.HTML(\n",
" textwrap.dedent(\"\"\"\n",
" \u003cvideo width=\"640\" height=\"480\" controls\u003e\n",
" \u003csource src=\"data:{mime};base64,{b64}\" type=\"{mime}\"\u003e\n",
" Your browser does not support the video tag.\n",
" \u003c/video\u003e\n",
" \"\"\").format(mime=mime, b64=b64))\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "kCdWsbO1afIJ"
},
"outputs": [],
"source": [
"tmp_video_path = \"tmp_preview.mp4\"\n",
"\n",
"filenames = sorted(glob.glob(f\"sample_images/{test_sequence_name}/*.jpg\"))\n",
"img = cv2.imread(filenames[0])\n",
"height, width, layers = img.shape\n",
"size = (width, height)\n",
"\n",
"video_writer = cv2.VideoWriter(\n",
" filename=tmp_video_path,\n",
" fourcc=cv2.VideoWriter_fourcc(*\"MP4V\"), \n",
" fps=15, \n",
" frameSize=size)\n",
" \n",
"for filename in tqdm(filenames):\n",
" img = cv2.imread(filename)\n",
" video_writer.write(img)\n",
"cv2.destroyAllWindows()\n",
"video_writer.release()"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "cHsKpPyviWmF"
},
"source": [
"Re-encode the video, and reduce its size (Colab crashes if you try to embed the full size video)."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "_li0qe-gh1iT"
},
"outputs": [],
"source": [
"subprocess.check_call([\n",
" \"ffmpeg\", \"-y\", \"-i\", tmp_video_path,\n",
" \"-vf\",\"scale=800:-1\",\n",
" \"-crf\", \"18\",\n",
" \"-preset\", \"veryfast\",\n",
" \"-vcodec\", \"libx264\", preview_video_path])"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "2ItoiHyYQGya"
},
"source": [
"The images you downloaded are frames of a movie showing a top view of a coral reef with crown-of-thorns starfish. The movie looks like this:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {
"background_save": true
},
"id": "SiOsbr8xePkg"
},
"outputs": [],
"source": [
"embed_video_file(preview_video_path)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "9Z0DTbWrZMZ-"
},
"source": [
"The goal of the model is to put boxes around all of the starfish. Each starfish gets its own ID, and that ID will be stable as the camera passes over it."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "OW5gGixy1osE"
},
"source": [
"## Perform the COTS detection inference and tracking.\n",
"\n",
"The detection inference has the following four main steps:\n",
"1. Read all images in the order of image indexes and convert them into uint8 TF tensors (Line 45-54).\n",
"2. Feed the TF image tensors into the model (Line 61) and get the detection output `detections`. In particular, the shape of input tensor is [batch size, height, width, number of channels]. In this demo project, the input shape is [4, 1080, 1920, 3].\n",
"3. The inference output `detections` contains four variables: `num_detections` (the number of detected objects), `detection_boxes` (the coordinates of each COTS object's bounding box), `detection_classes` (the class label of each detected object), `detection_scores` (the confidence score of each detected COTS object).\n",
"4. To track the movement of each detected object across frames, in each frame's detection, the tracker will estimate each tracked COTS object's position if COTS is not detected.\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {
"background_save": true
},
"id": "vHIarsxH1svL"
},
"outputs": [],
"source": [
"# Record all the detected COTS objects with the scores equal to or greater than the threshold\n",
"threshold = 0.4\n",
"_CLASS_ID_TO_LABEL = ('COTS',)\n",
"# Create a tracker object\n",
"tracker = OpticalFlowTracker(tid=1)\n",
"# Record tracking responses from the tracker\n",
"detection_result = []\n",
"# Record the length of each tracking sequence\n",
"track_length_dict = {}\n",
"\n",
"base_time = tf.timestamp()\n",
"\n",
"# Format tracker response, and save it into a new object.\n",
"def format_tracker_response(file_path, tracks, seq_length_dict):\n",
" new_track_list = []\n",
" for track in tracks:\n",
" detection_columns = [\n",
" _CLASS_ID_TO_LABEL[track.det.class_id],\n",
" str(track.det.score),\n",
" str(track.id),\n",
" str(len(track.linked_dets)),\n",
" str(round(track.det.x0)),\n",
" str(round(track.det.y0)),\n",
" str(round(track.det.x1 - track.det.x0)),\n",
" str(round(track.det.y1 - track.det.y0))\n",
" ]\n",
"\n",
" if str(track.id) not in seq_length_dict:\n",
" seq_length_dict[str(track.id)] = len(track.linked_dets)\n",
" else:\n",
" if len(track.linked_dets) \u003e seq_length_dict[str(track.id)]:\n",
" seq_length_dict[str(track.id)] = len(track.linked_dets)\n",
" new_track_list.append({\"score\":str(round(track.det.score, 3)), \"seq_id\": str(track.id), \"seq_idx\": str(len(track.linked_dets)),\n",
" \"x0\": round(track.det.x0), \"y0\": round(track.det.y0), \"x1\": round(track.det.x1), \"y1\": round(track.det.y1)})\n",
" return file_path, new_track_list, seq_length_dict\n",
"\n",
"# Read a jpg image and decode it to a uint8 tf tensor.\n",
"def parse_image(filename):\n",
" image = tf.io.read_file(filename)\n",
" image = tf.io.decode_jpeg(image)\n",
" return (tf.timestamp(), filename, image)\n",
"\n",
"# Create a data loader\n",
"file_list = sorted(glob.glob(f\"sample_images/{test_sequence_name}/*.jpg\"))\n",
"list_ds = tf.data.Dataset.from_tensor_slices(file_list)\n",
"images_ds = list_ds.map(parse_image)\n",
"\n",
"# Traverse the dataset with batch size = 1, you cannot change the batch size\n",
"for data in tqdm(images_ds.batch(1, drop_remainder=True)):\n",
" # timestamp is used for recording the order of frames\n",
" timestamp, file_path, image = data\n",
" timestamp = (timestamp - base_time) * 1000\n",
" # get detection result\n",
" detections = model_fn(image)\n",
" num_detections = detections['num_detections'].numpy().astype(np.int32)\n",
" detection_boxes = detections['detection_boxes'].numpy()\n",
" detection_classes = detections['detection_classes'].numpy().astype(np.int32)\n",
" detection_scores = detections['detection_scores'].numpy()\n",
"\n",
" batch_size, img_h, img_w = image.shape[0:3]\n",
"\n",
" for batch_index in range(batch_size):\n",
" valid_indices = detection_scores[batch_index, :] \u003e= threshold\n",
" classes = detection_classes[batch_index, valid_indices]\n",
" scores = detection_scores[batch_index, valid_indices]\n",
" boxes = detection_boxes[batch_index, valid_indices, :]\n",
" detections = []\n",
"\n",
" for class_id, score, box in zip(classes, scores, boxes):\n",
" detections.append(\n",
" Detection(\n",
" class_id=class_id,\n",
" score=score,\n",
" x0=box[1] * img_w,\n",
" y0=box[0] * img_h,\n",
" x1=box[3] * img_w,\n",
" y1=box[2] * img_h,\n",
" ))\n",
" # Feed detection results and the corresponding timestamp to the tracker, and then get tracker response\n",
" tracks = tracker.update(image[batch_index].numpy(), detections, timestamp[batch_index])\n",
" base_file_path, track_list, track_length_dict = format_tracker_response(file_path[batch_index].numpy().decode(\"utf-8\"), tracks, track_length_dict)\n",
" detection_result.append((base_file_path, track_list))"
]
},
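{
"cell_type": "markdown",
"metadata": {},
"source": [
"`detection_csv_path` is defined near the top of the notebook but never written. As an optional extra, here is a minimal sketch (with an assumed column layout, not a CSIRO-defined schema) that saves the per-frame tracking output to that CSV file."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Optional addition: save the tracking output to detection_csv_path.\n",
"# The column layout is an assumption, chosen to mirror the track dicts above.\n",
"import csv\n",
"\n",
"with open(detection_csv_path, 'w', newline='') as csv_file:\n",
" writer = csv.writer(csv_file)\n",
" writer.writerow(['file_path', 'seq_id', 'seq_idx', 'score',\n",
"                  'x0', 'y0', 'x1', 'y1'])\n",
" for frame_path, track_list in detection_result:\n",
"  for track in track_list:\n",
"   writer.writerow([frame_path, track['seq_id'], track['seq_idx'],\n",
"                    track['score'], track['x0'], track['y0'],\n",
"                    track['x1'], track['y1']])\n",
"print(f'Wrote tracks for {len(detection_result)} frames to {detection_csv_path}')"
]
},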
{
"cell_type": "markdown",
"metadata": {
"id": "QkpmYRyFAMlM"
},
"source": [
"# Output the detection results and play the result video\n",
"Once the inference is done, we use OpenCV to draw the bounding boxes (Line 9-10) and write the tracked COTS's information (Line 13-20: `COTS ID` `(sequence index/ sequence length)`) on each frame's image. Finally, we combine all frames into a video for visualisation."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {
"background_save": true
},
"id": "gWMJG7g95MGk"
},
"outputs": [],
"source": [
"detection_full_video_path = \"COTS_detection_full_size.mp4\"\n",
"detect_video_writer = cv2.VideoWriter(\n",
" filename=detection_full_video_path,\n",
" fourcc=cv2.VideoWriter_fourcc(*\"MP4V\"), \n",
" fps=15, \n",
" frameSize=size)\n",
"\n",
"for file_path, tracks in tqdm(detection_result):\n",
" image = cv2.imread(file_path)\n",
" for track in tracks:\n",
" # Draw the predicted bounding box\n",
" cv2.rectangle(image, (track['x0'], track['y0']),\n",
" (track['x1'], track['y1']),\n",
" (0, 140, 255), thickness=2,)\n",
" # Write the tracked COTS ID, and its corresponding tracking index and tracking sequence length\n",
" cv2.putText(image, f\"{track['seq_id']}\", (track['x0'], track['y0']-10), cv2.FONT_HERSHEY_SIMPLEX, 0.9, (0, 140, 255), 2)\n",
" if len(track[\"seq_id\"]) == 1:\n",
" offset = 20\n",
" elif len(track[\"seq_id\"]) == 2:\n",
" offset = 40\n",
" else:\n",
" offset = 60\n",
" cv2.putText(image, \n",
" f\"({track['seq_idx']}/{track_length_dict[track['seq_id']]})\",\n",
" (track['x0'] + offset, track['y0']-10),\n",
" cv2.FONT_HERSHEY_SIMPLEX,\n",
" 0.6, (0, 140, 255), 2)\n",
" detect_video_writer.write(image)\n",
"cv2.destroyAllWindows()\n",
"\n",
"detect_video_writer.release()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {
"background_save": true
},
"id": "9s1myz67jcV8"
},
"outputs": [],
"source": [
"subprocess.check_call([\n",
" \"ffmpeg\",\"-y\", \"-i\", detection_full_video_path,\n",
" \"-vf\",\"scale=800:-1\",\n",
" \"-crf\", \"18\",\n",
" \"-preset\", \"veryfast\",\n",
" \"-vcodec\", \"libx264\", detection_small_video_path])"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {
"background_save": true
},
"id": "wsK5cvX5jkL7"
},
"outputs": [],
"source": [
"embed_video_file(detection_small_video_path)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "n1oOgMR2zzIl"
},
"source": [
"The output video is now saved as movie at `detection_full_video_path`. You can download your video by uncommenting the following code."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {
"background_save": true
},
"id": "tyHucK8lbGXk"
},
"outputs": [],
"source": [
"#try:\n",
"# from google.colab import files\n",
"# files.download(detection_full_video_path)\n",
"#except ImportError:\n",
"# pass"
]
}
],
"metadata": {
"accelerator": "GPU",
"cells": [
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"cellView": "form",
"id": "xBH8CcrkV3IU"
},
"outputs": [],
"source": [
"#@title Licensed under the Apache License, Version 2.0 (the \"License\");\n",
"# you may not use this file except in compliance with the License.\n",
"# You may obtain a copy of the License at\n",
"#\n",
"# https://www.apache.org/licenses/LICENSE-2.0\n",
"#\n",
"# Unless required by applicable law or agreed to in writing, software\n",
"# distributed under the License is distributed on an \"AS IS\" BASIS,\n",
"# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
"# See the License for the specific language governing permissions and\n",
"# limitations under the License."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "9CzbXNRovpbc"
},
"source": [
"# Crown-of-Thorns Starfish Detection Pipeline"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "Lpb0yoNjiWhw"
},
"source": [
"<table class=\"tfo-notebook-buttons\" align=\"left\">\n",
" <td>\n",
" <a target=\"_blank\" href=\"https://colab.research.google.com/github/tensorflow/models/blob/master/official/projects/cots_detector/COTS_detection_inference_and_tracking.ipynb\"><img src=\"https://www.tensorflow.org/images/colab_logo_32px.png\" />Run in Google Colab</a>\n",
" </td>\n",
" <td>\n",
" <a target=\"_blank\" href=\"https://github.com/tensorflow/models/blob/master/official/projects/cots_detector/COTS_detection_inference_and_tracking.ipynb\"><img src=\"https://www.tensorflow.org/images/GitHub-Mark-32px.png\" />View on GitHub</a>\n",
" </td>\n",
"</table>"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "GUQ1x137ysLD"
},
"source": [
"Coral reefs are some of the most diverse and important ecosystems in the world , however they face a number of rising threats that have resulted in massive global declines. In Australia, outbreaks of the coral-eating crown-of-thorns starfish (COTS) have been shown to cause major coral loss, with just 15 starfish in a hectare being able to strip a reef of 90% of its coral tissue. While COTS naturally exist in the Indo-Pacific, overfishing and excess run-off nutrients have led to massive outbreaks that are devastating already vulnerable coral communities.\n",
"\n",
"Controlling COTS populations is critical to promoting coral growth and resilience, so Google teamed up with Australia’s national science agency, [CSIRO](https://www.csiro.au/en/), to tackle this problem. We trained ML object detection models to help scale underwater surveys, enabling the monitoring and mapping out these harmful invertebrates with the ultimate goal of helping control teams to address and prioritize outbreaks."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "jDiIX2xawkJw"
},
"source": [
"### This notebook\n",
"\n",
"This notebook tutorial shows how to detect COTS using a pre-trained COTS detector implemented in TensorFlow. On top of just running the model on each frame of the video, the tracking code in this notebook aligns detections from frame to frame creating a consistent track for each COTS. Each track is given an id and frame count. Here is an example image from a video of a reef showing labeled COTS starfish.\n",
"\n",
"<img src=\"https://storage.googleapis.com/download.tensorflow.org/data/cots_detection/COTS_detected_sample.png\">"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "YxCF1t-Skag8"
},
"source": [
"It is recommended to enable GPU to accelerate the inference. On CPU, this runs for about 40 minutes, but on GPU it takes only 10 minutes. (from colab menu: *Runtime > Change runtime type > Hardware accelerator > select \"GPU\"*)."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "a4R2T97u442o"
},
"source": [
"Install all needed packages."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "5Gs7XvCGlwlj"
},
"outputs": [],
"source": [
"# remove the existing datascience package to avoid package conflicts in the colab environment\n",
"!pip3 uninstall -y datascience\n",
"!pip3 install -q opencv-python"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "w-UQ87240x5R"
},
"outputs": [],
"source": [
"# Imports\n",
"import base64\n",
"import copy\n",
"import dataclasses\n",
"import glob\n",
"import logging\n",
"import mimetypes\n",
"import os\n",
"import pathlib\n",
"import subprocess\n",
"import time\n",
"import textwrap\n",
"\n",
"from absl import logging as absl_logging\n",
"from IPython import display\n",
"import cv2\n",
"import numpy as np\n",
"import tensorflow as tf\n",
"from tqdm import tqdm"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "gsSclJg4sJbX"
},
"source": [
"Define all needed variables."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "iKMCvnZEXBBT"
},
"outputs": [],
"source": [
"model_name = \"cots_1080_v1\" #@param [\"cots_1080_v1\", \"cots_720_v1\"]\n",
"test_sequence_name = \"test3\" #@param [\"test1\", \"test2\", \"test3\"]"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "ORLJSdLq4-gd"
},
"outputs": [],
"source": [
"cots_model = f\"https://storage.googleapis.com/download.tensorflow.org/models/cots_detection/{model_name}.zip\"\n",
"\n",
"# Alternatively, this dataset can be downloaded through CSIRO's Data Access Portal at https://data.csiro.au/collection/csiro:54830v2\n",
"sample_data_link = f\"https://storage.googleapis.com/download.tensorflow.org/data/cots_detection/sample_images.zip\"\n",
"\n",
"preview_video_path = \"preview.mp4\"\n",
"detection_small_video_path = \"COTS_detection.mp4\"\n",
"detection_csv_path = \"detections.csv\""
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "fVq6vNBTxM62"
},
"source": [
"Also, download the trained COTS detection model that matches your preferences above."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "No5jRA1TxXj0"
},
"outputs": [],
"source": [
"model_path = tf.keras.utils.get_file(origin=cots_model)\n",
"# Unzip model\n",
"!mkdir {model_name}\n",
"!unzip -o -q {model_path} -d {model_name}"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "FNwP3s-5xgaF"
},
"source": [
"You also need to retrieve the sample data. This sample data is made up of a series of chronological images."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "DF_c_ZMXdPRN"
},
"outputs": [],
"source": [
"sample_data_path = tf.keras.utils.get_file(origin=sample_data_link)\n",
"# Unzip data\n",
"!mkdir sample_images\n",
"!unzip -o -q {sample_data_path} -d sample_images"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "d0iALUwM0g2p"
},
"source": [
"# Load the model and perform inference and tracking on sample data\n",
"Load trained model from disk and create the inference function `model_fn()`. This might take a little while."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "HXQnNjwl8Beu"
},
"outputs": [],
"source": [
"absl_logging.set_verbosity(absl_logging.ERROR)\n",
"\n",
"tf.config.optimizer.set_experimental_options({'auto_mixed_precision': True})\n",
"tf.config.optimizer.set_jit(True)\n",
"\n",
"model_fn = tf.saved_model.load(model_name).signatures['serving_default']"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "jbZ-7ICCENWG"
},
"source": [
"# Define **OpticalFlowTracker** class and its related classes\n",
"\n",
"These help track the movement of each COTS object throughout the image frames."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "tybwY3eaY803"
},
"outputs": [],
"source": [
"def box_area(x0, y0, x1, y1):\n",
" return (x1 - x0 + 1) * (y1 - y0 + 1)\n",
"\n",
"@dataclasses.dataclass\n",
"class Detection:\n",
" \"\"\"Detection dataclass.\"\"\"\n",
" class_id: int\n",
" score: float\n",
" x0: float\n",
" y0: float\n",
" x1: float\n",
" y1: float\n",
"\n",
" def __repr__(self):\n",
" return (f'Class {self.class_id}, score {self.score}, '\n",
" f'box ({self.x0}, {self.y0}, {self.x1}, {self.y1})')\n",
"\n",
" def area(self):\n",
" return box_area(self.x0, self.y0, self.x1, self.y1)\n",
"\n",
" def iou(self, other):\n",
" overlap_x0 = max(self.x0, other.x0)\n",
" overlap_y0 = max(self.y0, other.y0)\n",
" overlap_x1 = min(self.x1, other.x1)\n",
" overlap_y1 = min(self.y1, other.y1)\n",
" if overlap_x0 < overlap_x1 and overlap_y0 < overlap_y1:\n",
" overlap_area = box_area(overlap_x0, overlap_y0, overlap_x1,\n",
" overlap_y1)\n",
" return overlap_area / (self.area() + other.area() - overlap_area)\n",
" else:\n",
" return 0\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "gCQFfAkaY_WN"
},
"outputs": [],
"source": [
"class Tracklet:\n",
" def __init__(self, timestamp, detection):\n",
" self.timestamp = timestamp\n",
" # Store a copy here to make sure the coordinates will not be updated\n",
" # when the optical flow propagation runs using another reference to this\n",
" # detection.\n",
" self.detection = copy.deepcopy(detection)\n",
"\n",
" def __repr__(self):\n",
" return f'Time {self.timestamp}, ' + self.detection.__repr__()\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "7qVW1a_YZBgL"
},
"outputs": [],
"source": [
"class Track:\n",
" \"\"\"Tracker entries.\"\"\"\n",
" def __init__(self, id, detection):\n",
" self.id = id\n",
" self.linked_dets = []\n",
" self.det = detection\n",
"\n",
" def __repr__(self):\n",
" result = f'Track {self.id}'\n",
" for linked_det in self.linked_dets:\n",
" result += '\\n' + linked_det.__repr__()\n",
" return result\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "3j2Ka1uGEoz4"
},
"outputs": [],
"source": [
"class OpticalFlowTracker:\n",
" \"\"\"Optical flow tracker.\"\"\"\n",
" def __init__(self, tid, ft=3.0, iou=0.5, tt=2.0, bb=32, size=64, its=20,\n",
" eps=0.03, levels=3):\n",
" self.track_id = tid\n",
" # How long to apply optical flow tracking without getting positive \n",
" # detections (sec).\n",
" self.track_flow_time = ft * 1000\n",
" # Required IoU overlap to link a detection to a track.\n",
" self.overlap_threshold = iou\n",
" # Used to detect if detector needs to be reset.\n",
" self.time_threshold = tt * 1000\n",
" self.border = bb\n",
" # Size of optical flow region.\n",
" self.of_size = (size, size)\n",
" self.of_criteria = (cv2.TermCriteria_COUNT + cv2.TermCriteria_EPS, its, \n",
" eps)\n",
" self.of_levels= levels\n",
"\n",
" self.tracks = []\n",
" self.prev_image = None\n",
" self.prev_time = -1\n",
"\n",
" def update(self, image_bgr, detections, timestamp):\n",
" start = time.time()\n",
" num_optical_flow_calls = 0\n",
"\n",
" image = cv2.cvtColor(image_bgr, cv2.COLOR_BGR2GRAY)\n",
"\n",
" image_w = image.shape[1]\n",
" image_h = image.shape[0]\n",
"\n",
" # Assume tracker is invalid if too much time has passed!\n",
" if (self.prev_time > 0 and\n",
" timestamp - self.prev_time > self.time_threshold):\n",
" logging.info(\n",
" 'Too much time since last update, resetting tracker.')\n",
" self.tracks = []\n",
"\n",
" # Remove tracks which are:\n",
" # - Touching the image edge.\n",
" # - Have existed for a long time without linking a real detection.\n",
" active_tracks = []\n",
" for track in self.tracks:\n",
" if (track.det.x0 < self.border or track.det.y0 < self.border or\n",
" track.det.x1 >= (image_w - self.border) or\n",
" track.det.y1 >= (image_h - self.border)):\n",
" logging.info(f'Removing track {track.id} because it\\'s near the border')\n",
" continue\n",
"\n",
" time_since_last_detection = timestamp - track.linked_dets[-1].timestamp\n",
" if (time_since_last_detection > self.track_flow_time):\n",
" logging.info(f'Removing track {track.id} because it\\'s too old '\n",
" f'({time_since_last_detection:.02f}s)')\n",
" continue\n",
"\n",
" active_tracks.append(track)\n",
"\n",
" self.tracks = active_tracks\n",
"\n",
" # Run optical flow to update existing tracks.\n",
" if self.prev_time > 0:\n",
" # print('Running optical flow propagation.')\n",
" of_params = {\n",
" 'winSize': self.of_size,\n",
" 'maxLevel': self.of_levels,\n",
" 'criteria': self.of_criteria\n",
" }\n",
" for track in self.tracks:\n",
" input_points = np.float32([[[(track.det.x0 + track.det.x1) / 2,\n",
" (track.det.y0 + track.det.y1) / 2]]])\n",
" output_points, status, error = cv2.calcOpticalFlowPyrLK(\n",
" self.prev_image, image, input_points, None, **of_params)\n",
" num_optical_flow_calls += 1\n",
" w = track.det.x1 - track.det.x0\n",
" h = track.det.y1 - track.det.y0\n",
" # print(f'Detection before flow update: {track.det}')\n",
" track.det.x0 = output_points[0][0][0] - w * 0.5\n",
" track.det.y0 = output_points[0][0][1] - h * 0.5\n",
" track.det.x1 = output_points[0][0][0] + w * 0.5\n",
" track.det.y1 = output_points[0][0][1] + h * 0.5\n",
" # print(f'Detection after flow update: {track.det}')\n",
"\n",
"\n",
" # Insert new detections.\n",
" detected_obj_track_ids = set()\n",
"\n",
" for detection in detections:\n",
" if (detection.x0 < self.border or detection.y0 < self.border or\n",
" detection.x1 >= image_w - self.border or\n",
" detection.y1 >= image_h - self.border):\n",
" # print('Skipping detection because it\\'s close to the border.')\n",
" continue\n",
"\n",
" # See if detection can be linked to an existing track.\n",
" linked = False\n",
" overlap_index = 0\n",
" overlap_max = -1000\n",
" for track_index, track in enumerate(self.tracks):\n",
" # print(f'Testing track {track_index}')\n",
" if track.det.class_id != detection.class_id:\n",
" continue\n",
" overlap = detection.iou(track.det)\n",
" if overlap > overlap_max:\n",
" overlap_index = track_index\n",
" overlap_max = overlap\n",
"\n",
" # Link to existing track with maximal IoU.\n",
" if overlap_max > self.overlap_threshold:\n",
" track = self.tracks[overlap_index]\n",
" track.det = detection\n",
" track.linked_dets.append(Tracklet(timestamp, detection))\n",
" detected_obj_track_ids.add(track.id)\n",
" linked = True\n",
"\n",
" if not linked:\n",
" logging.info(f'Creating new track with ID {self.track_id}')\n",
" new_track = Track(self.track_id, detection)\n",
" new_track.linked_dets.append(Tracklet(timestamp, detection))\n",
" detected_obj_track_ids.add(self.track_id)\n",
" self.tracks.append(new_track)\n",
" self.track_id += 1\n",
"\n",
" for track in self.tracks:\n",
" # If the detector does not find the obj but estimated in the tracker, \n",
" # add the estimated one to that tracker's linked_dets\n",
" if track.id not in detected_obj_track_ids:\n",
" track.linked_dets.append(Tracklet(timestamp, track.det))\n",
"\n",
" self.prev_image = image\n",
" self.prev_time = timestamp\n",
"\n",
" if num_optical_flow_calls > 0:\n",
" tracking_ms = int(1000 * (time.time() - start))\n",
" logging.info(f'Tracking took {tracking_ms}ms, '\n",
" f'{num_optical_flow_calls} optical flow calls')\n",
"\n",
" return self.tracks"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "Kkg3SazB1edC"
},
"source": [
"Create a list of images to work on from the downloaded files."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "u0fqXQUzdZCu"
},
"outputs": [],
"source": [
"def embed_video_file(path: os.PathLike) -> display.HTML:\n",
" \"\"\"Embeds a file in the notebook as an html tag with a data-url.\"\"\"\n",
" path = pathlib.Path(path)\n",
" mime, unused_encoding = mimetypes.guess_type(str(path))\n",
" data = path.read_bytes()\n",
"\n",
" b64 = base64.b64encode(data).decode()\n",
" return display.HTML(\n",
" textwrap.dedent(\"\"\"\n",
" <video width=\"640\" height=\"480\" controls>\n",
" <source src=\"data:{mime};base64,{b64}\" type=\"{mime}\">\n",
" Your browser does not support the video tag.\n",
" </video>\n",
" \"\"\").format(mime=mime, b64=b64))\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "kCdWsbO1afIJ"
},
"outputs": [],
"source": [
"tmp_video_path = \"tmp_preview.mp4\"\n",
"\n",
"filenames = sorted(glob.glob(f\"sample_images/{test_sequence_name}/*.jpg\"))\n",
"img = cv2.imread(filenames[0])\n",
"height, width, layers = img.shape\n",
"size = (width, height)\n",
"\n",
"video_writer = cv2.VideoWriter(\n",
" filename=tmp_video_path,\n",
" fourcc=cv2.VideoWriter_fourcc(*\"MP4V\"), \n",
" fps=15, \n",
" frameSize=size)\n",
" \n",
"for filename in tqdm(filenames):\n",
" img = cv2.imread(filename)\n",
" video_writer.write(img)\n",
"cv2.destroyAllWindows()\n",
"video_writer.release()"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "cHsKpPyviWmF"
},
"source": [
"Re-encode the video, and reduce its size (Colab crashes if you try to embed the full size video)."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "_li0qe-gh1iT"
},
"outputs": [],
"source": [
"subprocess.check_call([\n",
" \"ffmpeg\", \"-y\", \"-i\", tmp_video_path,\n",
" \"-vf\",\"scale=800:-1\",\n",
" \"-crf\", \"18\",\n",
" \"-preset\", \"veryfast\",\n",
" \"-vcodec\", \"libx264\", preview_video_path])"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "2ItoiHyYQGya"
},
"source": [
"The images you downloaded are frames of a movie showing a top view of a coral reef with crown-of-thorns starfish. The movie looks like this:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {
"collapsed_sections": [],
"name": "COTS Detection Inference and Tracking Pipeline.ipynb",
"provenance": []
},
"kernelspec": {
"display_name": "Python 3",
"name": "python3"
},
"language_info": {
"name": "python"
}
"background_save": true
},
"id": "SiOsbr8xePkg"
},
"outputs": [],
"source": [
"embed_video_file(preview_video_path)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "9Z0DTbWrZMZ-"
},
"source": [
"The goal of the model is to put boxes around all of the starfish. Each starfish gets its own ID, and that ID will be stable as the camera passes over it."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "OW5gGixy1osE"
},
"source": [
"## Perform the COTS detection inference and tracking.\n",
"\n",
"The detection inference has the following four main steps:\n",
"1. Read all images in the order of image indexes and convert them into uint8 TF tensors (Line 45-54).\n",
"2. Feed the TF image tensors into the model (Line 61) and get the detection output `detections`. In particular, the shape of input tensor is [batch size, height, width, number of channels]. In this demo project, the input shape is [4, 1080, 1920, 3].\n",
"3. The inference output `detections` contains four variables: `num_detections` (the number of detected objects), `detection_boxes` (the coordinates of each COTS object's bounding box), `detection_classes` (the class label of each detected object), `detection_scores` (the confidence score of each detected COTS object).\n",
"4. To track the movement of each detected object across frames, in each frame's detection, the tracker will estimate each tracked COTS object's position if COTS is not detected.\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {
"background_save": true
},
"id": "vHIarsxH1svL"
},
"outputs": [],
"source": [
"# Record all the detected COTS objects with the scores equal to or greater than the threshold\n",
"threshold = 0.4\n",
"_CLASS_ID_TO_LABEL = ('COTS',)\n",
"# Create a tracker object\n",
"tracker = OpticalFlowTracker(tid=1)\n",
"# Record tracking responses from the tracker\n",
"detection_result = []\n",
"# Record the length of each tracking sequence\n",
"track_length_dict = {}\n",
"\n",
"base_time = tf.timestamp()\n",
"\n",
"# Format tracker response, and save it into a new object.\n",
"def format_tracker_response(file_path, tracks, seq_length_dict):\n",
" new_track_list = []\n",
" for track in tracks:\n",
" detection_columns = [\n",
" _CLASS_ID_TO_LABEL[track.det.class_id],\n",
" str(track.det.score),\n",
" str(track.id),\n",
" str(len(track.linked_dets)),\n",
" str(round(track.det.x0)),\n",
" str(round(track.det.y0)),\n",
" str(round(track.det.x1 - track.det.x0)),\n",
" str(round(track.det.y1 - track.det.y0))\n",
" ]\n",
"\n",
" if str(track.id) not in seq_length_dict:\n",
" seq_length_dict[str(track.id)] = len(track.linked_dets)\n",
" else:\n",
" if len(track.linked_dets) > seq_length_dict[str(track.id)]:\n",
" seq_length_dict[str(track.id)] = len(track.linked_dets)\n",
" new_track_list.append({\"score\":str(round(track.det.score, 3)), \"seq_id\": str(track.id), \"seq_idx\": str(len(track.linked_dets)),\n",
" \"x0\": round(track.det.x0), \"y0\": round(track.det.y0), \"x1\": round(track.det.x1), \"y1\": round(track.det.y1)})\n",
" return file_path, new_track_list, seq_length_dict\n",
"\n",
"# Read a jpg image and decode it to a uint8 tf tensor.\n",
"def parse_image(filename):\n",
" image = tf.io.read_file(filename)\n",
" image = tf.io.decode_jpeg(image)\n",
" return (tf.timestamp(), filename, image)\n",
"\n",
"# Create a data loader\n",
"file_list = sorted(glob.glob(f\"sample_images/{test_sequence_name}/*.jpg\"))\n",
"list_ds = tf.data.Dataset.from_tensor_slices(file_list)\n",
"images_ds = list_ds.map(parse_image)\n",
"\n",
"# Traverse the dataset with batch size = 1, you cannot change the batch size\n",
"for data in tqdm(images_ds.batch(1, drop_remainder=True)):\n",
" # timestamp is used for recording the order of frames\n",
" timestamp, file_path, image = data\n",
" timestamp = (timestamp - base_time) * 1000\n",
" # get detection result\n",
" detections = model_fn(image)\n",
" num_detections = detections['num_detections'].numpy().astype(np.int32)\n",
" detection_boxes = detections['detection_boxes'].numpy()\n",
" detection_classes = detections['detection_classes'].numpy().astype(np.int32)\n",
" detection_scores = detections['detection_scores'].numpy()\n",
"\n",
" batch_size, img_h, img_w = image.shape[0:3]\n",
"\n",
" for batch_index in range(batch_size):\n",
" valid_indices = detection_scores[batch_index, :] >= threshold\n",
" classes = detection_classes[batch_index, valid_indices]\n",
" scores = detection_scores[batch_index, valid_indices]\n",
" boxes = detection_boxes[batch_index, valid_indices, :]\n",
" detections = []\n",
"\n",
" for class_id, score, box in zip(classes, scores, boxes):\n",
" detections.append(\n",
" Detection(\n",
" class_id=class_id,\n",
" score=score,\n",
" x0=box[1] * img_w,\n",
" y0=box[0] * img_h,\n",
" x1=box[3] * img_w,\n",
" y1=box[2] * img_h,\n",
" ))\n",
" # Feed detection results and the corresponding timestamp to the tracker, and then get tracker response\n",
" tracks = tracker.update(image[batch_index].numpy(), detections, timestamp[batch_index])\n",
" base_file_path, track_list, track_length_dict = format_tracker_response(file_path[batch_index].numpy().decode(\"utf-8\"), tracks, track_length_dict)\n",
" detection_result.append((base_file_path, track_list))"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "QkpmYRyFAMlM"
},
"source": [
"# Output the detection results and play the result video\n",
"Once the inference is done, we use OpenCV to draw the bounding boxes (Line 9-10) and write the tracked COTS's information (Line 13-20: `COTS ID` `(sequence index/ sequence length)`) on each frame's image. Finally, we combine all frames into a video for visualisation."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {
"background_save": true
},
"id": "gWMJG7g95MGk"
},
"outputs": [],
"source": [
"detection_full_video_path = \"COTS_detection_full_size.mp4\"\n",
"detect_video_writer = cv2.VideoWriter(\n",
" filename=detection_full_video_path,\n",
" fourcc=cv2.VideoWriter_fourcc(*\"MP4V\"), \n",
" fps=15, \n",
" frameSize=size)\n",
"\n",
"for file_path, tracks in tqdm(detection_result):\n",
" image = cv2.imread(file_path)\n",
" for track in tracks:\n",
" # Draw the predicted bounding box\n",
" cv2.rectangle(image, (track['x0'], track['y0']),\n",
" (track['x1'], track['y1']),\n",
" (0, 140, 255), thickness=2,)\n",
" # Write the tracked COTS ID, and its corresponding tracking index and tracking sequence length\n",
" cv2.putText(image, f\"{track['seq_id']}\", (track['x0'], track['y0']-10), cv2.FONT_HERSHEY_SIMPLEX, 0.9, (0, 140, 255), 2)\n",
" if len(track[\"seq_id\"]) == 1:\n",
" offset = 20\n",
" elif len(track[\"seq_id\"]) == 2:\n",
" offset = 40\n",
" else:\n",
" offset = 60\n",
" cv2.putText(image, \n",
" f\"({track['seq_idx']}/{track_length_dict[track['seq_id']]})\",\n",
" (track['x0'] + offset, track['y0']-10),\n",
" cv2.FONT_HERSHEY_SIMPLEX,\n",
" 0.6, (0, 140, 255), 2)\n",
" detect_video_writer.write(image)\n",
"cv2.destroyAllWindows()\n",
"\n",
"detect_video_writer.release()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {
"background_save": true
},
"id": "9s1myz67jcV8"
},
"outputs": [],
"source": [
"subprocess.check_call([\n",
" \"ffmpeg\",\"-y\", \"-i\", detection_full_video_path,\n",
" \"-vf\",\"scale=800:-1\",\n",
" \"-crf\", \"18\",\n",
" \"-preset\", \"veryfast\",\n",
" \"-vcodec\", \"libx264\", detection_small_video_path])"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {
"background_save": true
},
"id": "wsK5cvX5jkL7"
},
"outputs": [],
"source": [
"embed_video_file(detection_small_video_path)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "n1oOgMR2zzIl"
},
"source": [
"The output video is now saved as movie at `detection_full_video_path`. You can download your video by uncommenting the following code."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {
"background_save": true
},
"id": "tyHucK8lbGXk"
},
"outputs": [],
"source": [
"#try:\n",
"# from google.colab import files\n",
"# files.download(detection_full_video_path)\n",
"#except ImportError:\n",
"# pass"
]
}
],
"metadata": {
"accelerator": "GPU",
"colab": {
"collapsed_sections": [],
"name": "COTS Detection Inference and Tracking Pipeline.ipynb",
"provenance": []
},
"kernelspec": {
"display_name": "Python 3",
"name": "python3"
},
"nbformat": 4,
"nbformat_minor": 0
"language_info": {
"name": "python"
}
},
"nbformat": 4,
"nbformat_minor": 0
}