Commit ee4405c2 authored by A. Unique TensorFlower

Merge pull request #10640 from MarkDaoust:COTS

PiperOrigin-RevId: 453490561
parents 1f51347b b05de1c1
...@@ -37,14 +37,14 @@ ...@@ -37,14 +37,14 @@
"id": "Lpb0yoNjiWhw" "id": "Lpb0yoNjiWhw"
}, },
"source": [ "source": [
"<table class=\"tfo-notebook-buttons\" align=\"left\">\n", "\u003ctable class=\"tfo-notebook-buttons\" align=\"left\"\u003e\n",
" <td>\n", " \u003ctd\u003e\n",
" <a target=\"_blank\" href=\"https://colab.research.google.com/github/tensorflow/models/blob/master/official/projects/cots_detector/crown_of_thorns_starfish_detection_pipeline.ipynb\"><img src=\"https://www.tensorflow.org/images/colab_logo_32px.png\" />Run in Google Colab</a>\n", " \u003ca target=\"_blank\" href=\"https://colab.research.google.com/github/tensorflow/models/blob/master/official/projects/cots_detector/crown_of_thorns_starfish_detection_pipeline.ipynb?force_crab_mode=1\"\u003e\u003cimg src=\"https://www.tensorflow.org/images/colab_logo_32px.png\" /\u003eRun in Google Colab\u003c/a\u003e\n",
" </td>\n", " \u003c/td\u003e\n",
" <td>\n", " \u003ctd\u003e\n",
" <a target=\"_blank\" href=\"https://github.com/tensorflow/models/blob/master/official/projects/cots_detector/crown_of_thorns_starfish_detection_pipeline.ipynb\"><img src=\"https://www.tensorflow.org/images/GitHub-Mark-32px.png\" />View on GitHub</a>\n", " \u003ca target=\"_blank\" href=\"https://github.com/tensorflow/models/blob/master/official/projects/cots_detector/crown_of_thorns_starfish_detection_pipeline.ipynb\"\u003e\u003cimg src=\"https://www.tensorflow.org/images/GitHub-Mark-32px.png\" /\u003eView on GitHub\u003c/a\u003e\n",
" </td>\n", " \u003c/td\u003e\n",
"</table>" "\u003c/table\u003e"
] ]
}, },
{ {
...@@ -64,11 +64,11 @@ ...@@ -64,11 +64,11 @@
"id": "jDiIX2xawkJw" "id": "jDiIX2xawkJw"
}, },
"source": [ "source": [
"### This notebook\n", "## About this notebook\n",
"\n", "\n",
"This notebook tutorial shows how to detect COTS using a pre-trained COTS detector implemented in TensorFlow. On top of just running the model on each frame of the video, the tracking code in this notebook aligns detections from frame to frame creating a consistent track for each COTS. Each track is given an id and frame count. Here is an example image from a video of a reef showing labeled COTS starfish.\n", "This notebook tutorial shows how to detect COTS using a pre-trained COTS detector implemented in TensorFlow. On top of just running the model on each frame of the video, the tracking code in this notebook aligns detections from frame to frame creating a consistent track for each COTS. Each track is given an id and frame count. Here is an example image from a video of a reef showing labeled COTS starfish.\n",
"\n", "\n",
"<img src=\"https://storage.googleapis.com/download.tensorflow.org/data/cots_detection/COTS_detected_sample.png\">" "\u003cimg src=\"https://storage.googleapis.com/download.tensorflow.org/data/cots_detection/COTS_detected_sample.png\"\u003e"
] ]
}, },
{ {
...@@ -77,7 +77,7 @@ ...@@ -77,7 +77,7 @@
"id": "YxCF1t-Skag8" "id": "YxCF1t-Skag8"
}, },
"source": [ "source": [
"It is recommended to enable GPU to accelerate the inference. On CPU, this runs for about 40 minutes, but on GPU it takes only 10 minutes. (from colab menu: *Runtime > Change runtime type > Hardware accelerator > select \"GPU\"*)." "It is recommended to enable GPU to accelerate the inference. On CPU, this runs for about 40 minutes, but on GPU it takes only 10 minutes. (In Colab it should already be set to GPU in the Runtime menu: *Runtime \u003e Change runtime type \u003e Hardware accelerator \u003e select \"GPU\"*)."
] ]
}, },
{ {
...@@ -86,6 +86,8 @@ ...@@ -86,6 +86,8 @@
"id": "a4R2T97u442o" "id": "a4R2T97u442o"
}, },
"source": [ "source": [
"## Setup \n",
"\n",
"Install all needed packages." "Install all needed packages."
] ]
}, },
...@@ -99,7 +101,8 @@ ...@@ -99,7 +101,8 @@
"source": [ "source": [
"# remove the existing datascience package to avoid package conflicts in the colab environment\n", "# remove the existing datascience package to avoid package conflicts in the colab environment\n",
"!pip3 uninstall -y datascience\n", "!pip3 uninstall -y datascience\n",
"!pip3 install -q opencv-python" "!pip3 install -q opencv-python\n",
"!pip3 install PILLOW"
] ]
}, },
{ {
...@@ -122,11 +125,14 @@ ...@@ -122,11 +125,14 @@
"import subprocess\n", "import subprocess\n",
"import time\n", "import time\n",
"import textwrap\n", "import textwrap\n",
"from typing import Dict, Iterable, List, Optional, Tuple\n",
"\n", "\n",
"from absl import logging as absl_logging\n", "from absl import logging as absl_logging\n",
"from IPython import display\n", "from IPython import display\n",
"import cv2\n", "import cv2\n",
"import matplotlib.pyplot as plt\n",
"import numpy as np\n", "import numpy as np\n",
"import PIL.Image\n",
"import tensorflow as tf\n", "import tensorflow as tf\n",
"from tqdm import tqdm" "from tqdm import tqdm"
] ]
...@@ -170,13 +176,160 @@ ...@@ -170,13 +176,160 @@
"detection_csv_path = \"detections.csv\"" "detection_csv_path = \"detections.csv\""
] ]
}, },
{
"cell_type": "markdown",
"metadata": {
"id": "FNwP3s-5xgaF"
},
"source": [
"You also need to retrieve the sample data. This sample data is made up of a series of chronological images."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "DF_c_ZMXdPRN"
},
"outputs": [],
"source": [
"sample_data_path = tf.keras.utils.get_file(origin=sample_data_link)\n",
"# Unzip data\n",
"!mkdir sample_images\n",
"!unzip -o -q {sample_data_path} -d sample_images"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "Ghf-4E5-ZiJn"
},
"source": [
"Convert the images to a video file:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "kCdWsbO1afIJ"
},
"outputs": [],
"source": [
"tmp_video_path = \"tmp_preview.mp4\"\n",
"\n",
"filenames = sorted(glob.glob(f\"sample_images/{test_sequence_name}/*.jpg\"))\n",
"img = cv2.imread(filenames[0])\n",
"height, width, layers = img.shape\n",
"size = (width, height)\n",
"\n",
"video_writer = cv2.VideoWriter(\n",
" filename=tmp_video_path,\n",
" fourcc=cv2.VideoWriter_fourcc(*\"MP4V\"), \n",
" fps=15, \n",
" frameSize=size)\n",
" \n",
"for filename in tqdm(filenames):\n",
" img = cv2.imread(filename)\n",
" video_writer.write(img)\n",
"cv2.destroyAllWindows()\n",
"video_writer.release()"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "cHsKpPyviWmF"
},
"source": [
"Re-encode the video, and reduce its size (Colab crashes if you try to embed the full size video)."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "_li0qe-gh1iT"
},
"outputs": [],
"source": [
"subprocess.check_call([\n",
" \"ffmpeg\", \"-y\", \"-i\", tmp_video_path,\n",
" \"-vf\",\"scale=800:-1\",\n",
" \"-crf\", \"18\",\n",
" \"-preset\", \"veryfast\",\n",
" \"-vcodec\", \"libx264\", preview_video_path])"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "2ItoiHyYQGya"
},
"source": [
"The images you downloaded are frames of a movie showing a top view of a coral reef with crown-of-thorns starfish. Use the `base64` data-URL trick to embed the video in this notebook:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "u0fqXQUzdZCu"
},
"outputs": [],
"source": [
"def embed_video_file(path: os.PathLike) -\u003e display.HTML:\n",
" \"\"\"Embeds a file in the notebook as an html tag with a data-url.\"\"\"\n",
" path = pathlib.Path(path)\n",
" mime, unused_encoding = mimetypes.guess_type(str(path))\n",
" data = path.read_bytes()\n",
"\n",
" b64 = base64.b64encode(data).decode()\n",
" return display.HTML(\n",
" textwrap.dedent(\"\"\"\n",
" \u003cvideo width=\"640\" height=\"480\" controls\u003e\n",
" \u003csource src=\"data:{mime};base64,{b64}\" type=\"{mime}\"\u003e\n",
" Your browser does not support the video tag.\n",
" \u003c/video\u003e\n",
" \"\"\").format(mime=mime, b64=b64))\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "SiOsbr8xePkg"
},
"outputs": [],
"source": [
"embed_video_file(preview_video_path)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "9Z0DTbWrZMZ-"
},
"source": [
"Can you se them? there are lots. The goal of the model is to put boxes around all of the starfish. Each starfish will get its own ID, and that ID will be stable as the camera passes over it."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "d0iALUwM0g2p"
},
"source": [
"## Load the model"
]
},
{ {
"cell_type": "markdown", "cell_type": "markdown",
"metadata": { "metadata": {
"id": "fVq6vNBTxM62" "id": "fVq6vNBTxM62"
}, },
"source": [ "source": [
"Also, download the trained COTS detection model that matches your preferences above." "Download the trained COTS detection model that matches your preferences from earlier."
] ]
}, },
{ {
...@@ -196,246 +349,736 @@ ...@@ -196,246 +349,736 @@
{ {
"cell_type": "markdown", "cell_type": "markdown",
"metadata": { "metadata": {
"id": "FNwP3s-5xgaF" "id": "ezyuSHK5ap__"
}, },
"source": [ "source": [
"You also need to retrieve the sample data. This sample data is made up of a series of chronological images." "Load trained model from disk and create the inference function `model_fn()`. This might take a little while."
] ]
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": null, "execution_count": null,
"metadata": { "metadata": {
"id": "DF_c_ZMXdPRN" "id": "HXQnNjwl8Beu"
}, },
"outputs": [], "outputs": [],
"source": [ "source": [
"sample_data_path = tf.keras.utils.get_file(origin=sample_data_link)\n", "absl_logging.set_verbosity(absl_logging.ERROR)\n",
"# Unzip data\n", "\n",
"!mkdir sample_images\n", "tf.config.optimizer.set_experimental_options({'auto_mixed_precision': True})\n",
"!unzip -o -q {sample_data_path} -d sample_images" "tf.config.optimizer.set_jit(True)\n",
"\n",
"model_fn = tf.saved_model.load(model_name).signatures['serving_default']"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "OvLuznhUa7uG"
},
"source": [
"Here's one test image; how many COTS can you see?"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "XmQF_2L_a7Hu"
},
"outputs": [],
"source": [
"example_frame_number = 52\n",
"image = tf.io.read_file(filenames[example_frame_number])\n",
"image = tf.io.decode_jpeg(image)\n",
"\n",
"# Caution PIL and tf use \"RGB\" color order, while cv2 uses \"BGR\".\n",
"PIL.Image.fromarray(image.numpy())"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "KSOf4V8WhTHF"
},
"source": [
"## Raw model outputs\n",
"\n",
"Try running the model on the image. The model expects a batch of images so add an outer `batch` dimension before calling the model.\n",
"\n",
"Note: The model only runs correctly with a batch size of 1.\n",
"\n",
"The result is a dictionary with a number of fields. For all fields the first dimension of the shape is the `batch` dimension, "
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "iqLHo8h0c2pW"
},
"outputs": [],
"source": [
"image_batch = image[tf.newaxis, ...]\n",
"result = model_fn(image_batch)\n",
"\n",
"print(f\"{'image_batch':20s}- shape: {image_batch.shape}\")\n",
"\n",
"for key, value in result.items():\n",
" print(f\"{key:20s}- shape: {value.shape}\")"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "0xuNoKLCjyDz"
},
"source": [
"The `num_detections` field gives the number of valid detections, but this is always 100. There are always 100 locations that _could_ be a COTS."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "nGCDZJQvkIOL"
},
"outputs": [],
"source": [
"print('\\nnum_detections: ', result['num_detections'].numpy())"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "cSd7JJYqkPz7"
},
"source": [
"Similarly the `detection_classes` field is always `0`, since the model only detects 1 class: COTS."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "JoY8bJrfkcuS"
},
"outputs": [],
"source": [
"print('detection_classes: \\n', result['detection_classes'].numpy())"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "X2nVLSOokyog"
},
"source": [
"What actually matters here is the detection scores, indicating the quality of each detection: "
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "iepEgCc2jsRD"
},
"outputs": [],
"source": [
"result['detection_scores'].numpy()"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "Fn2B0nbplAFy"
},
"source": [
"You need to choose a threshold that determines what counts as a good detection. This frame has a few good detections:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "a30Uyc0WlK2a"
},
"outputs": [],
"source": [
"good_detections = result['detection_scores'] \u003e 0.4\n",
"good_detections.numpy()"
]
},
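As a rough sketch (not part of the original notebook), you can sweep a few candidate thresholds to see how many of the 100 candidate detections survive each one; this assumes the `result` dictionary from the cells above is still in scope:

```python
# Sketch only: count surviving detections for a few candidate thresholds,
# assuming `result` from the cells above is available.
import numpy as np

scores = result['detection_scores'].numpy()[0]  # 100 scores, highest first
for threshold in (0.2, 0.4, 0.6, 0.8):
    kept = int(np.sum(scores >= threshold))
    print(f"threshold {threshold:.1f}: {kept} detections kept")
```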
{
"cell_type": "markdown",
"metadata": {
"id": "Y_xrbQiAlWrK"
},
"source": [
"## Bounding boxes and detections\n",
"\n",
"Build a class to handle the detection boxes:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "S5inzqu-4JhT"
},
"outputs": [],
"source": [
"@dataclasses.dataclass(frozen=True)\n",
"class BBox:\n",
" x0: float\n",
" y0: float\n",
" x1: float\n",
" y1: float\n",
"\n",
" def replace(self, **kwargs):\n",
" d = self.__dict__.copy()\n",
" d.update(kwargs)\n",
" return type(self)(**d)\n",
"\n",
" @property\n",
" def center(self)-\u003e Tuple[float, float]:\n",
" return ((self.x0+self.x1)/2, (self.y0+self.y1)/2)\n",
" \n",
" @property\n",
" def width(self) -\u003e float:\n",
" return self.x1 - self.x0\n",
"\n",
" @property\n",
" def height(self) -\u003e float:\n",
" return self.y1 - self.y0\n",
"\n",
" @property\n",
" def area(self)-\u003e float:\n",
" return (self.x1 - self.x0 + 1) * (self.y1 - self.y0 + 1)\n",
" \n",
" def intersection(self, other)-\u003e Optional['BBox']:\n",
" x0 = max(self.x0, other.x0)\n",
" y0 = max(self.y0, other.y0)\n",
" x1 = min(self.x1, other.x1)\n",
" y1 = min(self.y1, other.y1)\n",
" if x0 \u003e x1 or y0 \u003e y1:\n",
" return None\n",
" return BBox(x0, y0, x1, y1)\n",
"\n",
" def iou(self, other):\n",
" intersection = self.intersection(other)\n",
" if intersection is None:\n",
" return 0\n",
" \n",
" ia = intersection.area\n",
"\n",
" return ia/(self.area + other.area - ia)\n",
" \n",
" def draw(self, image, label=None, color=(0, 140, 255)):\n",
" image = np.asarray(image)\n",
" cv2.rectangle(image, \n",
" (int(self.x0), int(self.y0)),\n",
" (int(self.x1), int(self.y1)),\n",
" color,\n",
" thickness=2)\n",
" if label is not None:\n",
" cv2.putText(image, str(label), \n",
" (int(self.x0), int(self.y0-10)),\n",
" cv2.FONT_HERSHEY_SIMPLEX,\n",
" 0.9, color, thickness=2)\n",
" return image"
]
},
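A quick sanity check of the `BBox` geometry helpers on two hand-made boxes (an illustrative sketch, not part of the original pipeline; the coordinates below are arbitrary):

```python
# Two made-up boxes: `b` covers the right half of `a`.
a = BBox(x0=0, y0=0, x1=100, y1=100)
b = BBox(x0=50, y0=0, x1=150, y1=100)

print(a.intersection(b))        # BBox(x0=50, y0=0, x1=100, y1=100)
print(round(a.iou(b), 3))       # ~0.34: the overlap is about half of each box
print(a.iou(BBox(x0=200, y0=200, x1=300, y1=300)))  # 0, no overlap
```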
{
"cell_type": "markdown",
"metadata": {
"id": "2izYMR9Q6Dn0"
},
"source": [
"And a class to represent a `Detection`, with a method to create a list of detections from the model's output:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "tybwY3eaY803"
},
"outputs": [],
"source": [
"@dataclasses.dataclass(frozen=True)\n",
"class Detection:\n",
" \"\"\"Detection dataclass.\"\"\"\n",
" class_id: int\n",
" score: float\n",
" bbox: BBox\n",
" threshold:float = 0.4\n",
"\n",
" def replace(self, **kwargs):\n",
" d = self.__dict__.copy()\n",
" d.update(kwargs)\n",
" return type(self)(**d)\n",
"\n",
" @classmethod\n",
" def process_model_output(\n",
" cls, image, detections: Dict[str, tf.Tensor]\n",
" ) -\u003e Iterable['Detection']:\n",
" \n",
" # The model only works on a batch size of 1.\n",
" detection_boxes = detections['detection_boxes'].numpy()[0]\n",
" detection_classes = detections['detection_classes'].numpy()[0].astype(np.int32)\n",
" detection_scores = detections['detection_scores'].numpy()[0]\n",
"\n",
" img_h, img_w = image.shape[0:2]\n",
"\n",
" valid_indices = detection_scores \u003e= cls.threshold\n",
" classes = detection_classes[valid_indices]\n",
" scores = detection_scores[valid_indices]\n",
" boxes = detection_boxes[valid_indices, :]\n",
" detections = []\n",
"\n",
" for class_id, score, box in zip(classes, scores, boxes):\n",
" detections.append(\n",
" Detection(\n",
" class_id=class_id,\n",
" score=score,\n",
" bbox=BBox(\n",
" x0=box[1] * img_w,\n",
" y0=box[0] * img_h,\n",
" x1=box[3] * img_w,\n",
" y1=box[2] * img_h,)))\n",
"\n",
" return detections"
]
},
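The model's `detection_boxes` come back as normalized `[ymin, xmin, ymax, xmax]` values, which `process_model_output` scales into pixel `BBox`es. Here is a minimal sketch of that conversion on a made-up box, assuming a 1920x1080 frame like the sample imagery:

```python
# Illustration only: convert one hypothetical normalized box to pixel coordinates.
img_h, img_w = 1080, 1920
ymin, xmin, ymax, xmax = 0.25, 0.10, 0.40, 0.20   # made-up values

pixel_box = BBox(x0=xmin * img_w, y0=ymin * img_h,
                 x1=xmax * img_w, y1=ymax * img_h)
print(pixel_box)  # BBox(x0=192.0, y0=270.0, x1=384.0, y1=432.0)
```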
{
"cell_type": "markdown",
"metadata": {
"id": "QRZ9Q5meHl84"
},
"source": [
"## Preview some detections\n",
"\n",
"Now you can preview the model's output:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "Px7AoFCn-psx"
},
"outputs": [],
"source": [
"detections = Detection.process_model_output(image, result)\n",
"\n",
"for n, det in enumerate(detections):\n",
" det.bbox.draw(image, label=n+1, color=(255, 140, 0))\n",
"\n",
"PIL.Image.fromarray(image.numpy())"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "B1q_n1xJLm60"
},
"source": [
"That works well for one frame, but to count the number of COTS in a video you'll need to track the detections from frame to frame. The raw detection indices are not stable, they're just sorted by the detection score. Below both sets of detections are overlaid on the second image with the first frame's detections in white and the second frame's in orange, the indices are not aligned. The positions are shifted because of camera motion between the two frames:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "PLtxJFPuLma0"
},
"outputs": [],
"source": [
"image2 = tf.io.read_file(filenames[example_frame_number+5]) # five frames later\n",
"image2 = tf.io.decode_jpeg(image2)\n",
"result2 = model_fn(image2[tf.newaxis, ...])\n",
"detections2 = Detection.process_model_output(image2, result2)\n",
"\n",
"for n, det in enumerate(detections):\n",
" det.bbox.draw(image2, label=n+1, color=(255, 255, 255))\n",
"\n",
"for n, det in enumerate(detections2):\n",
" det.bbox.draw(image2, label=n+1, color=(255, 140, 0))\n",
"\n",
"PIL.Image.fromarray(image2.numpy())"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "CoRxLon5MZ35"
},
"source": [
"## Use optical flow to align detections\n",
"\n",
"The two sets of bounding boxes above don't line up because of camera movement. \n",
"To see in more detail how tracks are aligned, initialize the tracker with the first image, and then run the optical flow step, `propagate_tracks`. "
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "wb_nkcPJJx2t"
},
"outputs": [],
"source": [
"def default_of_params():\n",
" its=20\n",
" eps=0.03\n",
" return {\n",
" 'winSize': (64,64),\n",
" 'maxLevel': 3,\n",
" 'criteria': (cv2.TermCriteria_COUNT + cv2.TermCriteria_EPS, its, eps)\n",
" }"
]
},
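For reference, here is what those defaults mean when they are passed to `cv2.calcOpticalFlowPyrLK` (a commented restatement of the values above, not new behaviour; it assumes `cv2` is already imported):

```python
# Same values as default_of_params(), with the meaning of each field spelled out.
of_params = {
    # Size (in pixels) of the search window around each tracked point.
    'winSize': (64, 64),
    # Number of extra pyramid levels; more levels let the tracker follow larger motions.
    'maxLevel': 3,
    # Stop after 20 iterations, or once the window moves by less than 0.03 px.
    'criteria': (cv2.TermCriteria_COUNT + cv2.TermCriteria_EPS, 20, 0.03),
}
```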
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "mHVPymG8F2ke"
},
"outputs": [],
"source": [
"def propagate_detections(detections, image1, image2, of_params=None):\n",
" if of_params is None:\n",
" of_params = default_of_params()\n",
"\n",
" bboxes = [det.bbox for det in detections]\n",
" centers = np.float32([[bbox.center for bbox in bboxes]])\n",
" widths = np.float32([[bbox.width for bbox in bboxes]])\n",
" heights = np.float32([[bbox.height for bbox in bboxes]])\n",
"\n",
"\n",
" new_centers, status, error = cv2.calcOpticalFlowPyrLK(\n",
" image1, image2, centers, None, **of_params)\n",
"\n",
" x0s = new_centers[...,0] - widths/2\n",
" x1s = new_centers[...,0] + widths/2\n",
" y0s = new_centers[...,1] - heights/2\n",
" y1s = new_centers[...,1] + heights/2\n",
"\n",
" updated_detections = []\n",
" for i, det in enumerate(detections):\n",
" det = det.replace(\n",
" bbox = BBox(x0s[0,i], y0s[0,i], x1s[0,i], y1s[0,i]))\n",
" updated_detections.append(det)\n",
" return updated_detections"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "dCjgvoZnOcBu"
},
"source": [
"Now keep the white boxes for the initial detections, and the orange boxes for the new set of detections. But add the optical-flow propagated tracks in green. You can see that by using optical-flow to propagate the old detections to the new frame the alignment is quite good. It's this alignment between the old and new detections (between the green and orange boxes) that allows the tracker to make a persistent track for each COTS. "
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "aeTny8YnHwTw"
},
"outputs": [],
"source": [
"image = tf.io.read_file(filenames[example_frame_number])\n",
"image = tf.io.decode_jpeg(image).numpy()\n",
"image_gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)\n",
"\n",
"image2 = tf.io.read_file(filenames[example_frame_number+5]) # five frames later\n",
"image2 = tf.io.decode_jpeg(image2).numpy()\n",
"image2_gray = cv2.cvtColor(image2, cv2.COLOR_BGR2GRAY)\n",
"\n",
"updated_detections = propagate_detections(detections, image_gray, image2_gray)\n",
"\n",
"\n",
"for det in detections:\n",
" det.bbox.draw(image2, color=(255, 255, 255))\n",
"\n",
"for det in updated_detections:\n",
" det.bbox.draw(image2, color=(0, 255, 0))\n",
"\n",
"for det in detections2:\n",
" det.bbox.draw(image2, color=(255, 140, 0))\n",
"\n",
"PIL.Image.fromarray(image2)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "jbZ-7ICCENWG"
},
"source": [
"## Define **OpticalFlowTracker** class\n",
"\n",
"These help track the movement of each COTS object across the video frames.\n",
"\n",
"The tracker collects related detections into `Track` objects. \n",
"\n",
"The class's init is defined below, it's methods are defined in the following cells.\n",
"\n",
"The `__init__` method just initializes the track counter (`track_id`), and sets some default values for the tracking and optical flow configurations. "
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "3j2Ka1uGEoz4"
},
"outputs": [],
"source": [
"class OpticalFlowTracker:\n",
" \"\"\"Optical flow tracker.\"\"\"\n",
"\n",
" @classmethod\n",
" def add_method(cls, fun):\n",
" \"\"\"Attach a new method to the class.\"\"\"\n",
" setattr(cls, fun.__name__, fun)\n",
"\n",
"\n",
" def __init__(self, tid=1, ft=3.0, iou=0.5, tt=2.0, bb=32, of_params=None):\n",
" # Bookkeeping for the tracks.\n",
" # The running track count, incremented for each new track.\n",
" self.track_id = tid\n",
" self.tracks = []\n",
" self.prev_image = None\n",
" self.prev_time = None\n",
"\n",
" # Configuration for the track cleanup logic.\n",
" # How long to apply optical flow tracking without getting positive \n",
" # detections (sec).\n",
" self.track_flow_time = ft * 1000\n",
" # Required IoU overlap to link a detection to a track.\n",
" self.overlap_threshold = iou\n",
" # Used to detect if detector needs to be reset.\n",
" self.time_threshold = tt * 1000\n",
" self.border = bb\n",
"\n",
" if of_params is None:\n",
" of_params = default_of_params()\n",
" self.of_params = of_params\n"
] ]
}, },
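The `add_method` helper above is what lets the later cells attach each method to `OpticalFlowTracker` one cell at a time. Here is a toy illustration of the same pattern (the `Greeter` class is hypothetical, not part of the notebook):

```python
class Greeter:

    @classmethod
    def add_method(cls, fun):
        """Attach a new method to the class, same pattern as OpticalFlowTracker."""
        setattr(cls, fun.__name__, fun)

# In a later "cell", decorate a free function to bolt it onto the class.
@Greeter.add_method
def hello(self, name):
    return f"hello, {name}"

print(Greeter().hello("reef"))  # hello, reef
```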
{ {
"cell_type": "markdown", "cell_type": "markdown",
"metadata": { "metadata": {
"id": "d0iALUwM0g2p" "id": "yBLSv0Fi_JJD"
}, },
"source": [ "source": [
"# Load the model and perform inference and tracking on sample data\n", "Internally the tracker will use small `Track` and `Tracklet` classes to organize the data. The `Tracklet` class is just a `Detection` with a timestamp, while a `Track` is a track ID, the most recent detection and a list of `Tracklet` objects forming the history of the track."
"Load trained model from disk and create the inference function `model_fn()`. This might take a little while."
] ]
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": null, "execution_count": null,
"metadata": { "metadata": {
"id": "HXQnNjwl8Beu" "id": "gCQFfAkaY_WN"
}, },
"outputs": [], "outputs": [],
"source": [ "source": [
"absl_logging.set_verbosity(absl_logging.ERROR)\n", "@dataclasses.dataclass(frozen=True)\n",
"\n", "class Tracklet:\n",
"tf.config.optimizer.set_experimental_options({'auto_mixed_precision': True})\n", " timestamp:float\n",
"tf.config.optimizer.set_jit(True)\n", " detection:Detection\n",
"\n", "\n",
"model_fn = tf.saved_model.load(model_name).signatures['serving_default']" " def replace(self, **kwargs):\n",
" d = self.__dict__.copy()\n",
" d.update(kwargs)\n",
" return type(self)(**d)"
] ]
}, },
{ {
"cell_type": "markdown", "cell_type": "code",
"execution_count": null,
"metadata": { "metadata": {
"id": "jbZ-7ICCENWG" "id": "7qVW1a_YZBgL"
}, },
"outputs": [],
"source": [ "source": [
"# Define **OpticalFlowTracker** class and its related classes\n", "@dataclasses.dataclass(frozen=True)\n",
"class Track:\n",
" \"\"\"Tracker entries.\"\"\"\n",
" id:int\n",
" det: Detection\n",
" linked_dets:List[Tracklet] = dataclasses.field(default_factory=list)\n",
"\n", "\n",
"These help track the movement of each COTS object throughout the image frames." " def replace(self, **kwargs):\n",
" d = self.__dict__.copy()\n",
" d.update(kwargs)\n",
" return type(self)(**d)"
] ]
}, },
{ {
"cell_type": "code", "cell_type": "markdown",
"execution_count": null,
"metadata": { "metadata": {
"id": "tybwY3eaY803" "id": "Ntl_4oUp_1nD"
}, },
"outputs": [],
"source": [ "source": [
"def box_area(x0, y0, x1, y1):\n", "The tracker keeps a list of active `Track` objects.\n",
" return (x1 - x0 + 1) * (y1 - y0 + 1)\n",
"\n",
"@dataclasses.dataclass\n",
"class Detection:\n",
" \"\"\"Detection dataclass.\"\"\"\n",
" class_id: int\n",
" score: float\n",
" x0: float\n",
" y0: float\n",
" x1: float\n",
" y1: float\n",
"\n", "\n",
" def __repr__(self):\n", "The main `update` method takes an image, along with the list of detections and the timestamp for that image. On each frame step it performs the following sub-tasks:\n",
" return (f'Class {self.class_id}, score {self.score}, '\n",
" f'box ({self.x0}, {self.y0}, {self.x1}, {self.y1})')\n",
"\n", "\n",
" def area(self):\n", "* The tracker uses optical flow to calculate where each `Track` expects to see a new `Detection`.\n",
" return box_area(self.x0, self.y0, self.x1, self.y1)\n", "* The tracker matches up the actual detections for the frame to the expected detections for each Track.\n",
"\n", "* If a detection doesn't get matched to an existing track, a new track is created for the detection.\n",
" def iou(self, other):\n", "* If a track stops getting assigned new detections, it is eventually deactivated. "
" overlap_x0 = max(self.x0, other.x0)\n",
" overlap_y0 = max(self.y0, other.y0)\n",
" overlap_x1 = min(self.x1, other.x1)\n",
" overlap_y1 = min(self.y1, other.y1)\n",
" if overlap_x0 < overlap_x1 and overlap_y0 < overlap_y1:\n",
" overlap_area = box_area(overlap_x0, overlap_y0, overlap_x1,\n",
" overlap_y1)\n",
" return overlap_area / (self.area() + other.area() - overlap_area)\n",
" else:\n",
" return 0\n"
] ]
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": null, "execution_count": null,
"metadata": { "metadata": {
"id": "gCQFfAkaY_WN" "id": "koZ0mjFTpiTv"
}, },
"outputs": [], "outputs": [],
"source": [ "source": [
"class Tracklet:\n", "@OpticalFlowTracker.add_method\n",
" def __init__(self, timestamp, detection):\n", "def update(self, image_bgr, detections, timestamp):\n",
" self.timestamp = timestamp\n", " start = time.time()\n",
" # Store a copy here to make sure the coordinates will not be updated\n", "\n",
" # when the optical flow propagation runs using another reference to this\n", " image = cv2.cvtColor(image_bgr, cv2.COLOR_BGR2GRAY)\n",
" # detection.\n", "\n",
" self.detection = copy.deepcopy(detection)\n", " # Remove dead tracks.\n",
" self.tracks = self.cleanup_tracks(image, timestamp)\n",
"\n", "\n",
" def __repr__(self):\n", " # Run optical flow to update existing tracks.\n",
" return f'Time {self.timestamp}, ' + self.detection.__repr__()\n" " if self.prev_time is not None:\n",
" self.tracks = self.propagate_tracks(image)\n",
"\n",
" # Update the track list based on the new detections\n",
" self.apply_detections_to_tracks(image, detections, timestamp)\n",
"\n",
" self.prev_image = image\n",
" self.prev_time = timestamp\n",
"\n",
" return self.tracks"
] ]
}, },
{ {
"cell_type": "code", "cell_type": "markdown",
"execution_count": null,
"metadata": { "metadata": {
"id": "7qVW1a_YZBgL" "id": "U-6__zF2CHFS"
}, },
"outputs": [],
"source": [ "source": [
"class Track:\n", "The `cleanup_tracks` method clears tracks that are too old or are too close to the edge of the image."
" \"\"\"Tracker entries.\"\"\"\n",
" def __init__(self, id, detection):\n",
" self.id = id\n",
" self.linked_dets = []\n",
" self.det = detection\n",
"\n",
" def __repr__(self):\n",
" result = f'Track {self.id}'\n",
" for linked_det in self.linked_dets:\n",
" result += '\\n' + linked_det.__repr__()\n",
" return result\n"
] ]
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": null, "execution_count": null,
"metadata": { "metadata": {
"id": "3j2Ka1uGEoz4" "id": "HQBj8GihjF3-"
}, },
"outputs": [], "outputs": [],
"source": [ "source": [
"class OpticalFlowTracker:\n", "@OpticalFlowTracker.add_method\n",
" \"\"\"Optical flow tracker.\"\"\"\n", "def cleanup_tracks(self, image, timestamp) -\u003e List[Track]:\n",
" def __init__(self, tid, ft=3.0, iou=0.5, tt=2.0, bb=32, size=64, its=20,\n",
" eps=0.03, levels=3):\n",
" self.track_id = tid\n",
" # How long to apply optical flow tracking without getting positive \n",
" # detections (sec).\n",
" self.track_flow_time = ft * 1000\n",
" # Required IoU overlap to link a detection to a track.\n",
" self.overlap_threshold = iou\n",
" # Used to detect if detector needs to be reset.\n",
" self.time_threshold = tt * 1000\n",
" self.border = bb\n",
" # Size of optical flow region.\n",
" self.of_size = (size, size)\n",
" self.of_criteria = (cv2.TermCriteria_COUNT + cv2.TermCriteria_EPS, its, \n",
" eps)\n",
" self.of_levels= levels\n",
"\n",
" self.tracks = []\n",
" self.prev_image = None\n",
" self.prev_time = -1\n",
"\n",
" def update(self, image_bgr, detections, timestamp):\n",
" start = time.time()\n",
" num_optical_flow_calls = 0\n",
"\n",
" image = cv2.cvtColor(image_bgr, cv2.COLOR_BGR2GRAY)\n",
"\n",
" image_w = image.shape[1]\n", " image_w = image.shape[1]\n",
" image_h = image.shape[0]\n", " image_h = image.shape[0]\n",
"\n", "\n",
" # Assume tracker is invalid if too much time has passed!\n", " # Assume tracker is invalid if too much time has passed!\n",
" if (self.prev_time > 0 and\n", " if (self.prev_time is not None and\n",
" timestamp - self.prev_time > self.time_threshold):\n", " timestamp - self.prev_time \u003e self.time_threshold):\n",
" logging.info(\n", " logging.info(\n",
" 'Too much time since last update, resetting tracker.')\n", " 'Too much time since last update, resetting tracker.')\n",
" self.tracks = []\n", " return []\n",
"\n", "\n",
" # Remove tracks which are:\n", " # Remove tracks which are:\n",
" # - Touching the image edge.\n", " # - Touching the image edge.\n",
" # - Have existed for a long time without linking a real detection.\n", " # - Have existed for a long time without linking a real detection.\n",
" active_tracks = []\n", " active_tracks = []\n",
" for track in self.tracks:\n", " for track in self.tracks:\n",
" if (track.det.x0 < self.border or track.det.y0 < self.border or\n", " bbox = track.det.bbox\n",
" track.det.x1 >= (image_w - self.border) or\n", " if (bbox.x0 \u003c self.border or bbox.y0 \u003c self.border or\n",
" track.det.y1 >= (image_h - self.border)):\n", " bbox.x1 \u003e= (image_w - self.border) or\n",
" bbox.y1 \u003e= (image_h - self.border)):\n",
" logging.info(f'Removing track {track.id} because it\\'s near the border')\n", " logging.info(f'Removing track {track.id} because it\\'s near the border')\n",
" continue\n", " continue\n",
"\n", "\n",
" time_since_last_detection = timestamp - track.linked_dets[-1].timestamp\n", " time_since_last_detection = timestamp - track.linked_dets[-1].timestamp\n",
" if (time_since_last_detection > self.track_flow_time):\n", " if (time_since_last_detection \u003e self.track_flow_time):\n",
" logging.info(f'Removing track {track.id} because it\\'s too old '\n", " logging.info(f'Removing track {track.id} because it\\'s too old '\n",
" f'({time_since_last_detection:.02f}s)')\n", " f'({time_since_last_detection:.02f}s)')\n",
" continue\n", " continue\n",
"\n", "\n",
" active_tracks.append(track)\n", " active_tracks.append(track)\n",
"\n", "\n",
" self.tracks = active_tracks\n", " return active_tracks"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "DVzNcESxC6vY"
},
"source": [
"The `propagate_tracks` method uses optical flow to update each track's bounding box's position to predict their location in the new image: "
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "0GycdAflCs6v"
},
"outputs": [],
"source": [
"@OpticalFlowTracker.add_method\n",
"def propagate_tracks(self, image):\n",
" if not self.tracks:\n",
" return self.tracks[:]\n",
"\n",
" detections = [track.det for track in self.tracks]\n",
" detections = propagate_detections(detections, self.prev_image, image, self.of_params)\n",
"\n", "\n",
" # Run optical flow to update existing tracks.\n", " return [track.replace(det=det) \n",
" if self.prev_time > 0:\n", " for track, det in zip(self.tracks, detections)]\n"
" # print('Running optical flow propagation.')\n", ]
" of_params = {\n", },
" 'winSize': self.of_size,\n", {
" 'maxLevel': self.of_levels,\n", "cell_type": "markdown",
" 'criteria': self.of_criteria\n", "metadata": {
" }\n", "id": "uLbVeetwD0ph"
" for track in self.tracks:\n", },
" input_points = np.float32([[[(track.det.x0 + track.det.x1) / 2,\n", "source": [
" (track.det.y0 + track.det.y1) / 2]]])\n", "The `apply_detections_to_tracks` method compares each detection to the updated bounding box for each track. The detection is added to the track that matches best, if the match is better than the `overlap_threshold`. If no track is better than the threshold, the detection is used to create a new track. \n",
" output_points, status, error = cv2.calcOpticalFlowPyrLK(\n",
" self.prev_image, image, input_points, None, **of_params)\n",
" num_optical_flow_calls += 1\n",
" w = track.det.x1 - track.det.x0\n",
" h = track.det.y1 - track.det.y0\n",
" # print(f'Detection before flow update: {track.det}')\n",
" track.det.x0 = output_points[0][0][0] - w * 0.5\n",
" track.det.y0 = output_points[0][0][1] - h * 0.5\n",
" track.det.x1 = output_points[0][0][0] + w * 0.5\n",
" track.det.y1 = output_points[0][0][1] + h * 0.5\n",
" # print(f'Detection after flow update: {track.det}')\n",
"\n", "\n",
"If a track has no new detection assigned to it the predicted detection is used."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "j6pfRhDRlApe"
},
"outputs": [],
"source": [
"@OpticalFlowTracker.add_method\n",
"def apply_detections_to_tracks(self, image, detections, timestamp):\n",
" image_w = image.shape[1]\n",
" image_h = image.shape[0]\n",
"\n", "\n",
" # Insert new detections.\n", " # Insert new detections.\n",
" detected_obj_track_ids = set()\n", " detected_obj_track_ids = set()\n",
"\n", "\n",
" for detection in detections:\n", " for detection in detections:\n",
" if (detection.x0 < self.border or detection.y0 < self.border or\n", " bbox = detection.bbox\n",
" detection.x1 >= image_w - self.border or\n", " if (bbox.x0 \u003c self.border or bbox.y0 \u003c self.border or\n",
" detection.y1 >= image_h - self.border):\n", " bbox.x1 \u003e= image_w - self.border or\n",
" # print('Skipping detection because it\\'s close to the border.')\n", " bbox.y1 \u003e= image_h - self.border):\n",
" logging.debug('Skipping detection because it\\'s close to the border.')\n",
" continue\n", " continue\n",
"\n", "\n",
" # See if detection can be linked to an existing track.\n", " # See if detection can be linked to an existing track.\n",
...@@ -443,18 +1086,18 @@ ...@@ -443,18 +1086,18 @@
" overlap_index = 0\n", " overlap_index = 0\n",
" overlap_max = -1000\n", " overlap_max = -1000\n",
" for track_index, track in enumerate(self.tracks):\n", " for track_index, track in enumerate(self.tracks):\n",
" # print(f'Testing track {track_index}')\n", " logging.debug('Testing track %d', track_index)\n",
" if track.det.class_id != detection.class_id:\n", " if track.det.class_id != detection.class_id:\n",
" continue\n", " continue\n",
" overlap = detection.iou(track.det)\n", " overlap = detection.bbox.iou(track.det.bbox)\n",
" if overlap > overlap_max:\n", " if overlap \u003e overlap_max:\n",
" overlap_index = track_index\n", " overlap_index = track_index\n",
" overlap_max = overlap\n", " overlap_max = overlap\n",
"\n", "\n",
" # Link to existing track with maximal IoU.\n", " # Link to existing track with maximal IoU.\n",
" if overlap_max > self.overlap_threshold:\n", " if overlap_max \u003e self.overlap_threshold:\n",
" track = self.tracks[overlap_index]\n", " track = self.tracks[overlap_index]\n",
" track.det = detection\n", " self.tracks[overlap_index] = track.replace(det=detection)\n",
" track.linked_dets.append(Tracklet(timestamp, detection))\n", " track.linked_dets.append(Tracklet(timestamp, detection))\n",
" detected_obj_track_ids.add(track.id)\n", " detected_obj_track_ids.add(track.id)\n",
" linked = True\n", " linked = True\n",
...@@ -471,264 +1114,275 @@ ...@@ -471,264 +1114,275 @@
" # If the detector does not find the obj but estimated in the tracker, \n", " # If the detector does not find the obj but estimated in the tracker, \n",
" # add the estimated one to that tracker's linked_dets\n", " # add the estimated one to that tracker's linked_dets\n",
" if track.id not in detected_obj_track_ids:\n", " if track.id not in detected_obj_track_ids:\n",
" track.linked_dets.append(Tracklet(timestamp, track.det))\n", " track.linked_dets.append(Tracklet(timestamp, track.det))"
"\n",
" self.prev_image = image\n",
" self.prev_time = timestamp\n",
"\n",
" if num_optical_flow_calls > 0:\n",
" tracking_ms = int(1000 * (time.time() - start))\n",
" logging.info(f'Tracking took {tracking_ms}ms, '\n",
" f'{num_optical_flow_calls} optical flow calls')\n",
"\n",
" return self.tracks"
] ]
}, },
{ {
"cell_type": "markdown", "cell_type": "markdown",
"metadata": { "metadata": {
"id": "Kkg3SazB1edC" "id": "gY0AH-KUHPlC"
}, },
"source": [ "source": [
"Create a list of images to work on from the downloaded files." "## Test run the tracker\n",
"\n",
"So reload the test images, and run the detections to test out the tracker.\n",
"\n",
"On the first frame it creates and returns one track per detection:"
] ]
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": null, "execution_count": null,
"metadata": { "metadata": {
"id": "u0fqXQUzdZCu" "id": "7Ekkj_XFGdfq"
}, },
"outputs": [], "outputs": [],
"source": [ "source": [
"def embed_video_file(path: os.PathLike) -> display.HTML:\n", "example_frame_number = 52\n",
" \"\"\"Embeds a file in the notebook as an html tag with a data-url.\"\"\"\n", "image = tf.io.read_file(filenames[example_frame_number])\n",
" path = pathlib.Path(path)\n", "image = tf.io.decode_jpeg(image)\n",
" mime, unused_encoding = mimetypes.guess_type(str(path))\n", "result = model_fn(image[tf.newaxis, ...])\n",
" data = path.read_bytes()\n", "detections = Detection.process_model_output(image, result)\n",
"\n", "\n",
" b64 = base64.b64encode(data).decode()\n", "tracker = OpticalFlowTracker()\n",
" return display.HTML(\n", "tracks = tracker.update(image.numpy(), detections, timestamp = 0)\n",
" textwrap.dedent(\"\"\"\n", "\n",
" <video width=\"640\" height=\"480\" controls>\n", "print(f'detections : {len(detections)}') \n",
" <source src=\"data:{mime};base64,{b64}\" type=\"{mime}\">\n", "print(f'tracks : {len(tracks)}')"
" Your browser does not support the video tag.\n", ]
" </video>\n", },
" \"\"\").format(mime=mime, b64=b64))\n" {
"cell_type": "markdown",
"metadata": {
"id": "WovDYdNMII-n"
},
"source": [
"On the second frame many of the detections get assigned to existing tracks:"
] ]
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": null, "execution_count": null,
"metadata": { "metadata": {
"id": "kCdWsbO1afIJ" "id": "7iFEKwgMGi5n"
}, },
"outputs": [], "outputs": [],
"source": [ "source": [
"tmp_video_path = \"tmp_preview.mp4\"\n", "image2 = tf.io.read_file(filenames[example_frame_number+5]) # five frames later\n",
"image2 = tf.io.decode_jpeg(image2)\n",
"result2 = model_fn(image2[tf.newaxis, ...])\n",
"detections2 = Detection.process_model_output(image2, result2)\n",
"\n", "\n",
"filenames = sorted(glob.glob(f\"sample_images/{test_sequence_name}/*.jpg\"))\n", "new_tracks = tracker.update(image2.numpy(), detections2, timestamp = 1000)\n",
"img = cv2.imread(filenames[0])\n",
"height, width, layers = img.shape\n",
"size = (width, height)\n",
"\n", "\n",
"video_writer = cv2.VideoWriter(\n", "print(f'detections : {len(detections2)}') \n",
" filename=tmp_video_path,\n", "print(f'tracks : {len(new_tracks)}')"
" fourcc=cv2.VideoWriter_fourcc(*\"MP4V\"), \n",
" fps=15, \n",
" frameSize=size)\n",
" \n",
"for filename in tqdm(filenames):\n",
" img = cv2.imread(filename)\n",
" video_writer.write(img)\n",
"cv2.destroyAllWindows()\n",
"video_writer.release()"
] ]
}, },
{ {
"cell_type": "markdown", "cell_type": "markdown",
"metadata": { "metadata": {
"id": "cHsKpPyviWmF" "id": "dbkedwiVrxnQ"
}, },
"source": [ "source": [
"Re-encode the video, and reduce its size (Colab crashes if you try to embed the full size video)." "Now the track IDs should be consistent:"
] ]
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": null, "execution_count": null,
"metadata": { "metadata": {
"id": "_li0qe-gh1iT" "id": "QexJR5gerw6q"
}, },
"outputs": [], "outputs": [],
"source": [ "source": [
"subprocess.check_call([\n", "test_img = image2.numpy()\n",
" \"ffmpeg\", \"-y\", \"-i\", tmp_video_path,\n", "for n,track in enumerate(tracks):\n",
" \"-vf\",\"scale=800:-1\",\n", " track.det.bbox.draw(test_img, label=n, color=(255, 255, 255))\n",
" \"-crf\", \"18\",\n", "\n",
" \"-preset\", \"veryfast\",\n", "for n,track in enumerate(new_tracks):\n",
" \"-vcodec\", \"libx264\", preview_video_path])" " track.det.bbox.draw(test_img, label=n, color=(255, 140, 0))\n",
"\n",
"PIL.Image.fromarray(test_img)"
] ]
}, },
{ {
"cell_type": "markdown", "cell_type": "markdown",
"metadata": { "metadata": {
"id": "2ItoiHyYQGya" "id": "OW5gGixy1osE"
},
"source": [
"## Perform the COTS detection inference and tracking."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "f21596933d08"
}, },
"source": [ "source": [
"The images you downloaded are frames of a movie showing a top view of a coral reef with crown-of-thorns starfish. The movie looks like this:" "The main tracking loop will perform the following: \n",
"\n",
"1. Load the images in order.\n",
"2. Run the model on the image.\n",
"3. Update the tracker with the new images and detections.\n",
"4. Keep information about each track (id, current index and length) analysis or display. \n",
"\n",
"The `TrackAnnotation` class, below, will collect the data about each track:"
] ]
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": null, "execution_count": null,
"metadata": { "metadata": {
"colab": { "id": "lESJE0qXxubm"
"background_save": true
},
"id": "SiOsbr8xePkg"
}, },
"outputs": [], "outputs": [],
"source": [ "source": [
"embed_video_file(preview_video_path)" "@dataclasses.dataclass(frozen=True)\n",
"class TrackAnnotation:\n",
" det: Detection\n",
" seq_id: int\n",
" seq_idx: int\n",
" seq_length: Optional[int] = None\n",
"\n",
" def replace(self, **kwargs):\n",
" d = self.__dict__.copy()\n",
" d.update(kwargs)\n",
" return type(self)(**d)\n",
"\n",
" def annotation_str(self):\n",
" return f\"{self.seq_id} ({self.seq_idx}/{self.seq_length})\"\n"
] ]
}, },
{ {
"cell_type": "markdown", "cell_type": "markdown",
"metadata": { "metadata": {
"id": "9Z0DTbWrZMZ-" "id": "3863fb28cd34"
},
"source": [
"The `parse_image` function, below, will take `(index, filename)` pairs load the images as tensors and return `(timestamp_ms, filename, image)` triples, assuming 30fps"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "Dn7efhr0GBGz"
}, },
"outputs": [],
"source": [ "source": [
"The goal of the model is to put boxes around all of the starfish. Each starfish gets its own ID, and that ID will be stable as the camera passes over it." "# Read a jpg image and decode it to a uint8 tf tensor.\n",
"def parse_image(index, filename):\n",
" image = tf.io.read_file(filename)\n",
" image = tf.io.decode_jpeg(image)\n",
" timestamp_ms = 1000*index/30 # assuming 30fps\n",
" return (timestamp_ms, filename, image)"
] ]
}, },
{ {
"cell_type": "markdown", "cell_type": "markdown",
"metadata": { "metadata": {
"id": "OW5gGixy1osE" "id": "8f878e4b0852"
}, },
"source": [ "source": [
"## Perform the COTS detection inference and tracking.\n", "Here is the main tracker loop. Note that initially the saved `TrackAnnotations` don't contain the track lengths. The lengths are collected in the `track_length_for_id` dict."
"\n",
"The detection inference has the following four main steps:\n",
"1. Read all images in the order of image indexes and convert them into uint8 TF tensors (Line 45-54).\n",
"2. Feed the TF image tensors into the model (Line 61) and get the detection output `detections`. In particular, the shape of input tensor is [batch size, height, width, number of channels]. In this demo project, the input shape is [4, 1080, 1920, 3].\n",
"3. The inference output `detections` contains four variables: `num_detections` (the number of detected objects), `detection_boxes` (the coordinates of each COTS object's bounding box), `detection_classes` (the class label of each detected object), `detection_scores` (the confidence score of each detected COTS object).\n",
"4. To track the movement of each detected object across frames, in each frame's detection, the tracker will estimate each tracked COTS object's position if COTS is not detected.\n"
] ]
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": null, "execution_count": null,
"metadata": { "metadata": {
"colab": { "id": "cqN8RGBgVbr4"
"background_save": true
},
"id": "vHIarsxH1svL"
}, },
"outputs": [], "outputs": [],
"source": [ "source": [
"# Record all the detected COTS objects with the scores equal to or greater than the threshold\n",
"threshold = 0.4\n",
"_CLASS_ID_TO_LABEL = ('COTS',)\n",
"# Create a tracker object\n", "# Create a tracker object\n",
"tracker = OpticalFlowTracker(tid=1)\n", "tracker = OpticalFlowTracker(tid=1)\n",
"# Record tracking responses from the tracker\n", "# Record tracking responses from the tracker\n",
"detection_result = []\n", "detection_result = []\n",
"# Record the length of each tracking sequence\n", "# Record the length of each tracking sequence\n",
"track_length_dict = {}\n", "track_length_for_id = {}\n",
"\n",
"base_time = tf.timestamp()\n",
"\n",
"# Format tracker response, and save it into a new object.\n",
"def format_tracker_response(file_path, tracks, seq_length_dict):\n",
" new_track_list = []\n",
" for track in tracks:\n",
" detection_columns = [\n",
" _CLASS_ID_TO_LABEL[track.det.class_id],\n",
" str(track.det.score),\n",
" str(track.id),\n",
" str(len(track.linked_dets)),\n",
" str(round(track.det.x0)),\n",
" str(round(track.det.y0)),\n",
" str(round(track.det.x1 - track.det.x0)),\n",
" str(round(track.det.y1 - track.det.y0))\n",
" ]\n",
"\n",
" if str(track.id) not in seq_length_dict:\n",
" seq_length_dict[str(track.id)] = len(track.linked_dets)\n",
" else:\n",
" if len(track.linked_dets) > seq_length_dict[str(track.id)]:\n",
" seq_length_dict[str(track.id)] = len(track.linked_dets)\n",
" new_track_list.append({\"score\":str(round(track.det.score, 3)), \"seq_id\": str(track.id), \"seq_idx\": str(len(track.linked_dets)),\n",
" \"x0\": round(track.det.x0), \"y0\": round(track.det.y0), \"x1\": round(track.det.x1), \"y1\": round(track.det.y1)})\n",
" return file_path, new_track_list, seq_length_dict\n",
"\n",
"# Read a jpg image and decode it to a uint8 tf tensor.\n",
"def parse_image(filename):\n",
" image = tf.io.read_file(filename)\n",
" image = tf.io.decode_jpeg(image)\n",
" return (tf.timestamp(), filename, image)\n",
"\n", "\n",
"# Create a data loader\n", "# Create a data loader\n",
"file_list = sorted(glob.glob(f\"sample_images/{test_sequence_name}/*.jpg\"))\n", "file_list = sorted(glob.glob(f\"sample_images/{test_sequence_name}/*.jpg\"))\n",
"list_ds = tf.data.Dataset.from_tensor_slices(file_list)\n", "list_ds = tf.data.Dataset.from_tensor_slices(file_list).enumerate()\n",
"images_ds = list_ds.map(parse_image)\n", "images_ds = list_ds.map(parse_image)\n",
"\n", "\n",
"# Traverse the dataset with batch size = 1, you cannot change the batch size\n", "# Traverse the dataset with batch size = 1, you cannot change the batch size\n",
"for data in tqdm(images_ds.batch(1, drop_remainder=True)):\n", "for timestamp_ms, file_path, images in tqdm(images_ds.batch(1, drop_remainder=True)):\n",
" # timestamp is used for recording the order of frames\n",
" timestamp, file_path, image = data\n",
" timestamp = (timestamp - base_time) * 1000\n",
" # get detection result\n", " # get detection result\n",
" detections = model_fn(image)\n", " detections = Detection.process_model_output(images[0], model_fn(images))\n",
" num_detections = detections['num_detections'].numpy().astype(np.int32)\n",
" detection_boxes = detections['detection_boxes'].numpy()\n",
" detection_classes = detections['detection_classes'].numpy().astype(np.int32)\n",
" detection_scores = detections['detection_scores'].numpy()\n",
"\n",
" batch_size, img_h, img_w = image.shape[0:3]\n",
"\n",
" for batch_index in range(batch_size):\n",
" valid_indices = detection_scores[batch_index, :] >= threshold\n",
" classes = detection_classes[batch_index, valid_indices]\n",
" scores = detection_scores[batch_index, valid_indices]\n",
" boxes = detection_boxes[batch_index, valid_indices, :]\n",
" detections = []\n",
"\n", "\n",
" for class_id, score, box in zip(classes, scores, boxes):\n",
" detections.append(\n",
" Detection(\n",
" class_id=class_id,\n",
" score=score,\n",
" x0=box[1] * img_w,\n",
" y0=box[0] * img_h,\n",
" x1=box[3] * img_w,\n",
" y1=box[2] * img_h,\n",
" ))\n",
" # Feed detection results and the corresponding timestamp to the tracker, and then get tracker response\n", " # Feed detection results and the corresponding timestamp to the tracker, and then get tracker response\n",
" tracks = tracker.update(image[batch_index].numpy(), detections, timestamp[batch_index])\n", " tracks = tracker.update(images[0].numpy(), detections, timestamp_ms[0])\n",
" base_file_path, track_list, track_length_dict = format_tracker_response(file_path[batch_index].numpy().decode(\"utf-8\"), tracks, track_length_dict)\n", " annotations = []\n",
" detection_result.append((base_file_path, track_list))" " for track in tracks:\n",
" anno = TrackAnnotation(\n",
" det=track.det,\n",
" seq_id = track.id,\n",
" seq_idx = len(track.linked_dets)\n",
" )\n",
" annotations.append(anno)\n",
" track_length_for_id[track.id] = len(track.linked_dets)\n",
" \n",
" detection_result.append((file_path.numpy()[0].decode(), annotations))"
] ]
}, },
{ {
"cell_type": "markdown", "cell_type": "markdown",
"metadata": { "metadata": {
"id": "QkpmYRyFAMlM" "id": "29306d7f32df"
}, },
"source": [ "source": [
"# Output the detection results and play the result video\n", "Once the tracking loop has completed you can update the track length (`seq_length`) for each annotation from the `track_length_for_id` dict:"
"Once the inference is done, we use OpenCV to draw the bounding boxes (Line 9-10) and write the tracked COTS's information (Line 13-20: `COTS ID` `(sequence index/ sequence length)`) on each frame's image. Finally, we combine all frames into a video for visualisation."
] ]
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": null, "execution_count": null,
"metadata": { "metadata": {
"colab": { "id": "oPSfnQ1o04Rx"
"background_save": true },
"outputs": [],
"source": [
"def update_annotation_lengths(detection_result, track_length_for_id):\n",
" new_result = []\n",
" for file_path, annotations in detection_result:\n",
" new_annotations = []\n",
" for anno in annotations:\n",
" anno = anno.replace(seq_length=track_length_for_id[anno.seq_id])\n",
" new_annotations.append(anno)\n",
" new_result.append((file_path, new_annotations))\n",
" return new_result"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "zda914lv1o_v"
},
"outputs": [],
"source": [
"detection_result = update_annotation_lengths(detection_result, track_length_for_id)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "QkpmYRyFAMlM"
},
"source": [
"## Output the detection results and play the result video\n",
"\n",
"Once the inference is done, we draw the bounding boxes and track information onto each frame's image. Finally, we combine all frames into a video for visualisation."
]
}, },
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "gWMJG7g95MGk" "id": "gWMJG7g95MGk"
}, },
"outputs": [], "outputs": [],
...@@ -740,26 +1394,10 @@ ...@@ -740,26 +1394,10 @@
" fps=15, \n", " fps=15, \n",
" frameSize=size)\n", " frameSize=size)\n",
"\n", "\n",
"for file_path, tracks in tqdm(detection_result):\n", "for file_path, annotations in tqdm(detection_result):\n",
" image = cv2.imread(file_path)\n", " image = cv2.imread(file_path)\n",
" for track in tracks:\n", " for anno in annotations:\n",
" # Draw the predicted bounding box\n", " anno.det.bbox.draw(image, label=anno.annotation_str(), color=(0, 140, 255))\n",
" cv2.rectangle(image, (track['x0'], track['y0']),\n",
" (track['x1'], track['y1']),\n",
" (0, 140, 255), thickness=2,)\n",
" # Write the tracked COTS ID, and its corresponding tracking index and tracking sequence length\n",
" cv2.putText(image, f\"{track['seq_id']}\", (track['x0'], track['y0']-10), cv2.FONT_HERSHEY_SIMPLEX, 0.9, (0, 140, 255), 2)\n",
" if len(track[\"seq_id\"]) == 1:\n",
" offset = 20\n",
" elif len(track[\"seq_id\"]) == 2:\n",
" offset = 40\n",
" else:\n",
" offset = 60\n",
" cv2.putText(image, \n",
" f\"({track['seq_idx']}/{track_length_dict[track['seq_id']]})\",\n",
" (track['x0'] + offset, track['y0']-10),\n",
" cv2.FONT_HERSHEY_SIMPLEX,\n",
" 0.6, (0, 140, 255), 2)\n",
" detect_video_writer.write(image)\n", " detect_video_writer.write(image)\n",
"cv2.destroyAllWindows()\n", "cv2.destroyAllWindows()\n",
"\n", "\n",
...@@ -770,9 +1408,6 @@ ...@@ -770,9 +1408,6 @@
"cell_type": "code", "cell_type": "code",
"execution_count": null, "execution_count": null,
"metadata": { "metadata": {
"colab": {
"background_save": true
},
"id": "9s1myz67jcV8" "id": "9s1myz67jcV8"
}, },
"outputs": [], "outputs": [],
...@@ -789,9 +1424,6 @@ ...@@ -789,9 +1424,6 @@
"cell_type": "code", "cell_type": "code",
"execution_count": null, "execution_count": null,
"metadata": { "metadata": {
"colab": {
"background_save": true
},
"id": "wsK5cvX5jkL7" "id": "wsK5cvX5jkL7"
}, },
"outputs": [], "outputs": [],
...@@ -812,9 +1444,6 @@ ...@@ -812,9 +1444,6 @@
"cell_type": "code", "cell_type": "code",
"execution_count": null, "execution_count": null,
"metadata": { "metadata": {
"colab": {
"background_save": true
},
"id": "tyHucK8lbGXk" "id": "tyHucK8lbGXk"
}, },
"outputs": [], "outputs": [],
...@@ -831,18 +1460,14 @@ ...@@ -831,18 +1460,14 @@
"accelerator": "GPU", "accelerator": "GPU",
"colab": { "colab": {
"collapsed_sections": [], "collapsed_sections": [],
"name": "Crown-of-Thorns Starfish Detection Pipeline", "name": "crown_of_thorns_starfish_detection_pipeline.ipynb",
"provenance": [] "toc_visible": true
}, },
"kernelspec": { "kernelspec": {
"display_name": "Python 3", "display_name": "Python 3",
"name": "python3" "name": "python3"
},
"language_info": {
"name": "python"
} }
}, },
"nbformat": 4, "nbformat": 4,
"nbformat_minor": 0 "nbformat_minor": 0
} }