" <a target=\"_blank\" href=\"https://colab.research.google.com/github/tensorflow/models/blob/master/official/projects/cots_detector/crown_of_thorns_starfish_detection_pipeline.ipynb\"><img src=\"https://www.tensorflow.org/images/colab_logo_32px.png\" />Run in Google Colab</a>\n",
" \u003ca target=\"_blank\" href=\"https://colab.research.google.com/github/tensorflow/models/blob/master/official/projects/cots_detector/crown_of_thorns_starfish_detection_pipeline.ipynb?force_crab_mode=1\"\u003e\u003cimg src=\"https://www.tensorflow.org/images/colab_logo_32px.png\" /\u003eRun in Google Colab\u003c/a\u003e\n",
" </td>\n",
" \u003c/td\u003e\n",
" <td>\n",
" \u003ctd\u003e\n",
" <a target=\"_blank\" href=\"https://github.com/tensorflow/models/blob/master/official/projects/cots_detector/crown_of_thorns_starfish_detection_pipeline.ipynb\"><img src=\"https://www.tensorflow.org/images/GitHub-Mark-32px.png\" />View on GitHub</a>\n",
" \u003ca target=\"_blank\" href=\"https://github.com/tensorflow/models/blob/master/official/projects/cots_detector/crown_of_thorns_starfish_detection_pipeline.ipynb\"\u003e\u003cimg src=\"https://www.tensorflow.org/images/GitHub-Mark-32px.png\" /\u003eView on GitHub\u003c/a\u003e\n",
" </td>\n",
" \u003c/td\u003e\n",
"</table>"
"\u003c/table\u003e"
]
},
{
"cell_type": "markdown",
"cell_type": "markdown",
"metadata": {
"metadata": {
"id": "GUQ1x137ysLD"
"id": "GUQ1x137ysLD"
},
"source": [
"source": [
"Coral reefs are some of the most diverse and important ecosystems in the world , however they face a number of rising threats that have resulted in massive global declines. In Australia, outbreaks of the coral-eating crown-of-thorns starfish (COTS) have been shown to cause major coral loss, with just 15 starfish in a hectare being able to strip a reef of 90% of its coral tissue. While COTS naturally exist in the Indo-Pacific, overfishing and excess run-off nutrients have led to massive outbreaks that are devastating already vulnerable coral communities.\n",
"Coral reefs are some of the most diverse and important ecosystems in the world , however they face a number of rising threats that have resulted in massive global declines. In Australia, outbreaks of the coral-eating crown-of-thorns starfish (COTS) have been shown to cause major coral loss, with just 15 starfish in a hectare being able to strip a reef of 90% of its coral tissue. While COTS naturally exist in the Indo-Pacific, overfishing and excess run-off nutrients have led to massive outbreaks that are devastating already vulnerable coral communities.\n",
"\n",
"\n",
"Controlling COTS populations is critical to promoting coral growth and resilience, so Google teamed up with Australia’s national science agency, [CSIRO](https://www.csiro.au/en/), to tackle this problem. We trained ML object detection models to help scale underwater surveys, enabling the monitoring and mapping out these harmful invertebrates with the ultimate goal of helping control teams to address and prioritize outbreaks."
"Controlling COTS populations is critical to promoting coral growth and resilience, so Google teamed up with Australia’s national science agency, [CSIRO](https://www.csiro.au/en/), to tackle this problem. We trained ML object detection models to help scale underwater surveys, enabling the monitoring and mapping out these harmful invertebrates with the ultimate goal of helping control teams to address and prioritize outbreaks."
]
},
{
"cell_type": "markdown",
"cell_type": "markdown",
"metadata": {
"metadata": {
"id": "jDiIX2xawkJw"
"id": "jDiIX2xawkJw"
},
"source": [
"source": [
"### This notebook\n",
"## About this notebook\n",
"\n",
"\n",
"This notebook tutorial shows how to detect COTS using a pre-trained COTS detector implemented in TensorFlow. On top of just running the model on each frame of the video, the tracking code in this notebook aligns detections from frame to frame creating a consistent track for each COTS. Each track is given an id and frame count. Here is an example image from a video of a reef showing labeled COTS starfish.\n",
"This notebook tutorial shows how to detect COTS using a pre-trained COTS detector implemented in TensorFlow. On top of just running the model on each frame of the video, the tracking code in this notebook aligns detections from frame to frame creating a consistent track for each COTS. Each track is given an id and frame count. Here is an example image from a video of a reef showing labeled COTS starfish.\n",
"It is recommended to enable GPU to accelerate the inference. On CPU, this runs for about 40 minutes, but on GPU it takes only 10 minutes. (from colab menu: *Runtime > Change runtime type > Hardware accelerator > select \"GPU\"*)."
"It is recommended to enable GPU to accelerate the inference. On CPU, this runs for about 40 minutes, but on GPU it takes only 10 minutes. (In Colab it should already be set to GPU in the Runtime menu: *Runtime \u003e Change runtime type \u003e Hardware accelerator \u003e select \"GPU\"*)."
]
},
{
"cell_type": "markdown",
"cell_type": "markdown",
"metadata": {
"metadata": {
"id": "a4R2T97u442o"
"id": "a4R2T97u442o"
},
"source": [
"source": [
"Install all needed packages."
"## Setup \n",
]
"\n",
},
"Install all needed packages."
{
]
"cell_type": "code",
},
"execution_count": null,
{
"metadata": {
"cell_type": "code",
"id": "5Gs7XvCGlwlj"
"execution_count": null,
},
"metadata": {
"outputs": [],
"id": "5Gs7XvCGlwlj"
"source": [
},
"# remove the existing datascience package to avoid package conflicts in the colab environment\n",
"outputs": [],
"!pip3 uninstall -y datascience\n",
"source": [
"!pip3 install -q opencv-python"
"# remove the existing datascience package to avoid package conflicts in the colab environment\n",
"Re-encode the video, and reduce its size (Colab crashes if you try to embed the full size video)."
"cell_type": "markdown",
]
"metadata": {
},
"id": "jbZ-7ICCENWG"
{
},
"cell_type": "code",
"source": [
"execution_count": null,
"# Define **OpticalFlowTracker** class and its related classes\n",
"metadata": {
"\n",
"id": "_li0qe-gh1iT"
"These help track the movement of each COTS object throughout the image frames."
},
]
"outputs": [],
},
"source": [
{
"subprocess.check_call([\n",
"cell_type": "code",
" \"ffmpeg\", \"-y\", \"-i\", tmp_video_path,\n",
"execution_count": null,
" \"-vf\",\"scale=800:-1\",\n",
"metadata": {
" \"-crf\", \"18\",\n",
"id": "tybwY3eaY803"
" \"-preset\", \"veryfast\",\n",
},
" \"-vcodec\", \"libx264\", preview_video_path])"
"outputs": [],
]
"source": [
},
"def box_area(x0, y0, x1, y1):\n",
{
" return (x1 - x0 + 1) * (y1 - y0 + 1)\n",
"cell_type": "markdown",
"\n",
"metadata": {
"@dataclasses.dataclass\n",
"id": "2ItoiHyYQGya"
"class Detection:\n",
},
" \"\"\"Detection dataclass.\"\"\"\n",
"source": [
" class_id: int\n",
"The images you downloaded are frames of a movie showing a top view of a coral reef with crown-of-thorns starfish. Use the `base64` data-URL trick to embed the video in this notebook:"
"Can you se them? there are lots. The goal of the model is to put boxes around all of the starfish. Each starfish will get its own ID, and that ID will be stable as the camera passes over it."
]
]
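{
"cell_type": "markdown",
"metadata": {},
"source": [
"The `embed_video_file` helper called in the next cell is not defined in this excerpt. As a rough, minimal sketch of the `base64` data-URL trick (not necessarily the notebook's exact implementation):"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import base64\n",
"from IPython import display\n",
"\n",
"def embed_video_file(path, width=800):\n",
"  \"\"\"Minimal sketch: embed an mp4 file as a base64 data URL.\"\"\"\n",
"  with open(path, 'rb') as f:\n",
"    encoded = base64.b64encode(f.read()).decode('utf-8')\n",
"  return display.HTML(\n",
"      f'<video width=\"{width}\" controls loop>'\n",
"      f'<source src=\"data:video/mp4;base64,{encoded}\" type=\"video/mp4\">'\n",
"      '</video>')"
]
},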
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {
"background_save": true
},
"id": "SiOsbr8xePkg"
},
"outputs": [],
"source": [
"embed_video_file(preview_video_path)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "9Z0DTbWrZMZ-"
},
"source": [
"Can you see them? There are lots. The goal of the model is to put boxes around all of the starfish. Each starfish gets its own ID, and that ID will be stable as the camera passes over it."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "d0iALUwM0g2p"
},
"source": [
"## Load the model"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "fVq6vNBTxM62"
},
"source": [
"Download the trained COTS detection model that matches your preferences from earlier."
]
},
"The images you downloaded are frames of a movie showing a top view of a coral reef with crown-of-thorns starfish. The movie looks like this:"
" cv2.rectangle(image, \n",
]
" (int(self.x0), int(self.y0)),\n",
},
" (int(self.x1), int(self.y1)),\n",
{
" color,\n",
"cell_type": "code",
" thickness=2)\n",
"execution_count": null,
" if label is not None:\n",
"metadata": {
" cv2.putText(image, str(label), \n",
"colab": {
" (int(self.x0), int(self.y0-10)),\n",
"background_save": true
" cv2.FONT_HERSHEY_SIMPLEX,\n",
},
" 0.9, color, thickness=2)\n",
"id": "SiOsbr8xePkg"
" return image"
},
]
"outputs": [],
},
"source": [
{
"embed_video_file(preview_video_path)"
"cell_type": "markdown",
]
"metadata": {
},
"id": "2izYMR9Q6Dn0"
{
},
"cell_type": "markdown",
"source": [
"metadata": {
"And a class to represent a `Detection`, with a method to create a list of detections from the model's output:"
"id": "9Z0DTbWrZMZ-"
]
},
},
"source": [
{
"The goal of the model is to put boxes around all of the starfish. Each starfish gets its own ID, and that ID will be stable as the camera passes over it."
"cell_type": "code",
]
"execution_count": null,
},
"metadata": {
{
"id": "tybwY3eaY803"
"cell_type": "markdown",
},
"metadata": {
"outputs": [],
"id": "OW5gGixy1osE"
"source": [
},
"@dataclasses.dataclass(frozen=True)\n",
"source": [
"class Detection:\n",
"## Perform the COTS detection inference and tracking.\n",
" \"\"\"Detection dataclass.\"\"\"\n",
"\n",
" class_id: int\n",
"The detection inference has the following four main steps:\n",
" score: float\n",
"1. Read all images in the order of image indexes and convert them into uint8 TF tensors (Line 45-54).\n",
" bbox: BBox\n",
"2. Feed the TF image tensors into the model (Line 61) and get the detection output `detections`. In particular, the shape of input tensor is [batch size, height, width, number of channels]. In this demo project, the input shape is [4, 1080, 1920, 3].\n",
" threshold:float = 0.4\n",
"3. The inference output `detections` contains four variables: `num_detections` (the number of detected objects), `detection_boxes` (the coordinates of each COTS object's bounding box), `detection_classes` (the class label of each detected object), `detection_scores` (the confidence score of each detected COTS object).\n",
"\n",
"4. To track the movement of each detected object across frames, in each frame's detection, the tracker will estimate each tracked COTS object's position if COTS is not detected.\n"
"That works well for one frame, but to count the number of COTS in a video you'll need to track the detections from frame to frame. The raw detection indices are not stable, they're just sorted by the detection score. Below both sets of detections are overlaid on the second image with the first frame's detections in white and the second frame's in orange, the indices are not aligned. The positions are shifted because of camera motion between the two frames:"
"The two sets of bounding boxes above don't line up because of camera movement. \n",
"# Output the detection results and play the result video\n",
"To see in more detail how tracks are aligned, initialize the tracker with the first image, and then run the optical flow step, `propagate_tracks`. "
"Once the inference is done, we use OpenCV to draw the bounding boxes (Line 9-10) and write the tracked COTS's information (Line 13-20: `COTS ID` `(sequence index/ sequence length)`) on each frame's image. Finally, we combine all frames into a video for visualisation."
"Now keep the white boxes for the initial detections, and the orange boxes for the new set of detections. But add the optical-flow propagated tracks in green. You can see that by using optical-flow to propagate the old detections to the new frame the alignment is quite good. It's this alignment between the old and new detections (between the green and orange boxes) that allows the tracker to make a persistent track for each COTS. "
"The output video is now saved as movie at `detection_full_video_path`. You can download your video by uncommenting the following code."
" det.bbox.draw(image2, color=(0, 255, 0))\n",
]
"\n",
},
"for det in detections2:\n",
{
" det.bbox.draw(image2, color=(255, 140, 0))\n",
"cell_type": "code",
"\n",
"execution_count": null,
"PIL.Image.fromarray(image2)"
"metadata": {
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "jbZ-7ICCENWG"
},
"source": [
"## Define **OpticalFlowTracker** class\n",
"\n",
"These help track the movement of each COTS object across the video frames.\n",
"\n",
"The tracker collects related detections into `Track` objects. \n",
"\n",
"The class's init is defined below, it's methods are defined in the following cells.\n",
"\n",
"The `__init__` method just initializes the track counter (`track_id`), and sets some default values for the tracking and optical flow configurations. "
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "3j2Ka1uGEoz4"
},
"outputs": [],
"source": [
"class OpticalFlowTracker:\n",
" \"\"\"Optical flow tracker.\"\"\"\n",
"\n",
" @classmethod\n",
" def add_method(cls, fun):\n",
" \"\"\"Attach a new method to the class.\"\"\"\n",
" # The running track count, incremented for each new track.\n",
" self.track_id = tid\n",
" self.tracks = []\n",
" self.prev_image = None\n",
" self.prev_time = None\n",
"\n",
" # Configuration for the track cleanup logic.\n",
" # How long to apply optical flow tracking without getting positive \n",
" # detections (sec).\n",
" self.track_flow_time = ft * 1000\n",
" # Required IoU overlap to link a detection to a track.\n",
" self.overlap_threshold = iou\n",
" # Used to detect if detector needs to be reset.\n",
" self.time_threshold = tt * 1000\n",
" self.border = bb\n",
"\n",
" if of_params is None:\n",
" of_params = default_of_params()\n",
" self.of_params = of_params\n"
]
},
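{
"cell_type": "markdown",
"metadata": {},
"source": [
"Later cells attach the tracker's methods with `add_method`. The pattern looks roughly like this (the method below is only an illustration, not one of the real tracker methods):"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Illustration only: how later cells add a method to OpticalFlowTracker.\n",
"@OpticalFlowTracker.add_method\n",
"def num_active_tracks(self):\n",
"  \"\"\"Example method (not part of the real tracker).\"\"\"\n",
"  return len(self.tracks)"
]
},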
{
"cell_type": "markdown",
"metadata": {
"id": "yBLSv0Fi_JJD"
},
"source": [
"Internally the tracker will use small `Track` and `Tracklet` classes to organize the data. The `Tracklet` class is just a `Detection` with a timestamp, while a `Track` is a track ID, the most recent detection and a list of `Tracklet` objects forming the history of the track."
"The tracker keeps a list of active `Track` objects.\n",
"\n",
"The main `update` method takes an image, along with the list of detections and the timestamp for that image. On each frame step it performs the following sub-tasks:\n",
"\n",
"* The tracker uses optical flow to calculate where each `Track` expects to see a new `Detection`.\n",
"* The tracker matches up the actual detections for the frame to the expected detections for each Track.\n",
"* If a detection doesn't get matched to an existing track, a new track is created for the detection.\n",
"* If a track stops getting assigned new detections, it is eventually deactivated. "
" for track, det in zip(self.tracks, detections)]\n"
]
},
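{
"cell_type": "markdown",
"metadata": {},
"source": [
"The matching step relies on an intersection-over-union (IoU) overlap score compared against `overlap_threshold`. A minimal standalone version of that computation, for illustration (the tracker's own implementation is not shown in this excerpt):"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def iou(a, b):\n",
"  \"\"\"IoU of two (x0, y0, x1, y1) boxes.\"\"\"\n",
"  ix0, iy0 = max(a[0], b[0]), max(a[1], b[1])\n",
"  ix1, iy1 = min(a[2], b[2]), min(a[3], b[3])\n",
"  inter = max(0.0, ix1 - ix0) * max(0.0, iy1 - iy0)\n",
"  union = ((a[2] - a[0]) * (a[3] - a[1]) +\n",
"           (b[2] - b[0]) * (b[3] - b[1]) - inter)\n",
"  return inter / union if union > 0 else 0.0\n",
"\n",
"iou((0, 0, 10, 10), (5, 5, 15, 15))  # -> about 0.143"
]
},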
{
"cell_type": "markdown",
"metadata": {
"id": "uLbVeetwD0ph"
},
"source": [
"The `apply_detections_to_tracks` method compares each detection to the updated bounding box for each track. The detection is added to the track that matches best, if the match is better than the `overlap_threshold`. If no track is better than the threshold, the detection is used to create a new track. \n",
"\n",
"If a track has no new detection assigned to it the predicted detection is used."
"The `parse_image` function, below, will take `(index, filename)` pairs load the images as tensors and return `(timestamp_ms, filename, image)` triples, assuming 30fps"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "Dn7efhr0GBGz"
},
"outputs": [],
"source": [
"# Read a jpg image and decode it to a uint8 tf tensor.\n",
"Here is the main tracker loop. Note that initially the saved `TrackAnnotations` don't contain the track lengths. The lengths are collected in the `track_length_for_id` dict."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "cqN8RGBgVbr4"
},
"outputs": [],
"source": [
"# Create a tracker object\n",
"tracker = OpticalFlowTracker(tid=1)\n",
"# Record tracking responses from the tracker\n",
"detection_result = []\n",
"# Record the length of each tracking sequence\n",
"## Output the detection results and play the result video\n",
"\n",
"Once the inference is done, we draw the bounding boxes and track information onto each frame's image. Finally, we combine all frames into a video for visualisation."