Commit 205964c7 authored by Dan Kondratyuk, committed by A. Unique TensorFlower

Add notebook to plot MoViNet video stream predictions.

PiperOrigin-RevId: 434767967
parent 2ac5a5c0
@@ -8,12 +8,17 @@ This repository is the official implementation of
 [MoViNets: Mobile Video Networks for Efficient Video
 Recognition](https://arxiv.org/abs/2103.11511).
 
-**[UPDATE 2021-07-12] Mobile Models Available via [TF Lite](#tf-lite-streaming-models)**
+**[UPDATE 2022-03-14] Quantized TF Lite models
+[available on TF Hub](https://tfhub.dev/s?deployment-format=lite&q=movinet)
+(also [see table](https://tfhub.dev/google/collections/movinet) for
+quantized performance)**
 
 <p align="center">
   <img src="https://storage.googleapis.com/tf_model_garden/vision/movinet/artifacts/hoverboard_stream.gif" height=500>
 </p>
 
+Create your own video plot like the one above with this [Colab notebook](https://colab.research.google.com/github/tensorflow/models/blob/master/official/projects/movinet/plot_movinet_video_stream_predictions.ipynb).
+
 ## Description
 
 Mobile Video Networks (MoViNets) are efficient video classification models
@@ -55,6 +60,8 @@ approach that performs redundant computation and limits temporal scope.
 
 ## History
 
+- **2022-03-14** Support quantized TF Lite models and add/update Colab
+  notebooks.
 - **2021-07-12** Add TF Lite support and replace 3D stream models with
   mobile-friendly (2+1)D stream.
 - **2021-05-30** Add streaming MoViNet checkpoints and examples.
@@ -165,7 +172,7 @@ different architecture. To download the old checkpoints, insert `_legacy` before
 
 For convenience, we provide converted TF Lite models for inference on mobile
 devices. See the [TF Lite Example](#tf-lite-example) to export and run your own
-models.
+models. We also provide [quantized TF Lite binaries via TF Hub](https://tfhub.dev/s?deployment-format=lite&q=movinet).
 
 For reference, MoViNet-A0-Stream runs with a similar latency to
 [MobileNetV3-Large]
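For readers who want to try the linked binaries, here is a minimal inference sketch. It is not part of this commit: the filename is hypothetical, and the snippet assumes a streaming MoViNet `.tflite` exported with the usual `image` input plus recurrent state inputs, driven through TensorFlow's signature-runner API:

```python
import tensorflow as tf

# Hypothetical local path; download a MoViNet stream model from TF Hub first.
interpreter = tf.lite.Interpreter(model_path='movinet_a0_stream.tflite')
runner = interpreter.get_signature_runner()

# Every input except 'image' is a recurrent state tensor; start them at zero.
init_states = {
    name: tf.zeros(x['shape'], dtype=x['dtype'])
    for name, x in runner.get_input_details().items()
    if name != 'image'
}

# Feed one frame per call, carrying the state dict between calls.
states = init_states
video = tf.random.uniform([1, 8, 172, 172, 3])  # stand-in for a real clip
for frame in tf.split(video, video.shape[1], axis=1):
    outputs = runner(**states, image=frame)
    logits = outputs.pop('logits')  # assumes the signature names it 'logits'
    states = outputs

probs = tf.nn.softmax(logits[0])
```

Carrying the state dict forward between calls is what lets the streaming models process arbitrarily long videos with constant memory.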
@@ -226,7 +233,7 @@ backbone = movinet.Movinet(
     use_external_states=False,
 )
 model = movinet_model.MovinetClassifier(
-    backbone, num_classes=600, output_states=True)
+    backbone, num_classes=600, output_states=False)
 
 # Create your example input here.
 # Refer to the paper for recommended input shapes.
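The one-line fix above makes the base-model snippet consistent: with `use_external_states=False` the classifier keeps no recurrent state, so `output_states` should be `False` as well. For contrast, here is a sketch of the streaming configuration, where both flags flip (imports as used elsewhere in this repository; treat it as illustrative rather than part of the diff):

```python
from official.projects.movinet.modeling import movinet
from official.projects.movinet.modeling import movinet_model

# Streaming variant: causal convolutions with externally fed states, so the
# classifier must also return the updated states alongside the logits.
backbone = movinet.Movinet(
    model_id='a0',
    causal=True,
    use_external_states=True,
)
model = movinet_model.MovinetClassifier(
    backbone, num_classes=600, output_states=True)
```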
@@ -105,6 +105,7 @@
       "cell_type": "code",
       "execution_count": null,
       "metadata": {
+        "cellView": "form",
         "id": "dx55NK3ZoZeh"
       },
       "outputs": [],

New file: official/projects/movinet/plot_movinet_video_stream_predictions.ipynb
{
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "qwBHHt-XvPqn"
},
"source": [
"# Plot MoViNet Video Stream Predictions\n",
"\n",
"This notebook uses [MoViNets (Mobile Video Networks)](https://github.com/tensorflow/models/tree/master/official/projects/movinet) to predict a human action in a streaming video and outputs a visualization of predictions on each frame.\n",
"\n",
"Provide a video URL or upload your own to see how predictions change over time. All models can be run on CPU.\n",
"\n",
"Pretrained models are provided by [TensorFlow Hub](https://tfhub.dev/google/collections/movinet/) and the [TensorFlow Model Garden](https://github.com/tensorflow/models/tree/master/official/projects/movinet), trained on [Kinetics 600](https://deepmind.com/research/open-source/kinetics) for video action classification. All Models use TensorFlow 2 with Keras for inference and training. See the [research paper](https://arxiv.org/pdf/2103.11511.pdf) for more details.\n",
"\n",
"Example output using [this gif](https://github.com/tensorflow/models/raw/f8af2291cced43fc9f1d9b41ddbf772ae7b0d7d2/official/projects/movinet/files/jumpingjack.gif) as input:\n",
"\n",
"![jumping jacks plot](https://storage.googleapis.com/tf_model_garden/vision/movinet/artifacts/jumpingjacks_plot.gif)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"cellView": "form",
"id": "ElvELd9mIfZe"
},
"outputs": [],
"source": [
"#@title Run this cell to initialize and setup a [MoViNet](https://github.com/tensorflow/models/tree/master/official/projects/movinet) model.\n",
"\n",
"\n",
"# Install the mediapy package for visualizing images/videos.\n",
"# See https://github.com/google/mediapy\n",
"!pip install -q mediapy\n",
"\n",
"# Run imports\n",
"import os\n",
"import io\n",
"\n",
"import matplotlib as mpl\n",
"import matplotlib.pyplot as plt\n",
"import mediapy as media\n",
"import numpy as np\n",
"import PIL\n",
"import pandas as pd\n",
"import tensorflow as tf\n",
"import tensorflow_datasets as tfds\n",
"import tensorflow_hub as hub\n",
"import tqdm\n",
"from google.colab import files\n",
"import urllib.request\n",
"\n",
"mpl.rcParams.update({\n",
" 'font.size': 10,\n",
"})\n",
"\n",
"\n",
"# Download Kinetics 600 label map\n",
"!wget https://raw.githubusercontent.com/tensorflow/models/f8af2291cced43fc9f1d9b41ddbf772ae7b0d7d2/official/projects/movinet/files/kinetics_600_labels.txt -O labels.txt -q\n",
"\n",
"with tf.io.gfile.GFile('labels.txt') as f:\n",
" lines = f.readlines()\n",
" KINETICS_600_LABELS_LIST = [line.strip() for line in lines]\n",
" KINETICS_600_LABELS = tf.constant(KINETICS_600_LABELS_LIST)\n",
"\n",
"def get_top_k(probs, k=5, label_map=KINETICS_600_LABELS):\n",
" \"\"\"Outputs the top k model labels and probabilities on the given video.\"\"\"\n",
" top_predictions = tf.argsort(probs, axis=-1, direction='DESCENDING')[:k]\n",
" top_labels = tf.gather(label_map, top_predictions, axis=-1)\n",
" top_labels = [label.decode('utf8') for label in top_labels.numpy()]\n",
" top_probs = tf.gather(probs, top_predictions, axis=-1).numpy()\n",
" return tuple(zip(top_labels, top_probs))\n",
"\n",
"def predict_top_k(model, video, k=5, label_map=KINETICS_600_LABELS):\n",
" \"\"\"Outputs the top k model labels and probabilities on the given video.\"\"\"\n",
" outputs = model.predict(video[tf.newaxis])[0]\n",
" probs = tf.nn.softmax(outputs)\n",
" return get_top_k(probs, k=k, label_map=label_map)\n",
"\n",
"def load_movinet_from_hub(model_id, model_mode, hub_version=3):\n",
" \"\"\"Loads a MoViNet model from TF Hub.\"\"\"\n",
" hub_url = f'https://tfhub.dev/tensorflow/movinet/{model_id}/{model_mode}/kinetics-600/classification/{hub_version}'\n",
"\n",
" encoder = hub.KerasLayer(hub_url, trainable=True)\n",
"\n",
" inputs = tf.keras.layers.Input(\n",
" shape=[None, None, None, 3],\n",
" dtype=tf.float32)\n",
"\n",
" if model_mode == 'base':\n",
" inputs = dict(image=inputs)\n",
" else:\n",
" # Define the state inputs, which is a dict that maps state names to tensors.\n",
" init_states_fn = encoder.resolved_object.signatures['init_states']\n",
" state_shapes = {\n",
" name: ([s if s \u003e 0 else None for s in state.shape], state.dtype)\n",
" for name, state in init_states_fn(tf.constant([0, 0, 0, 0, 3])).items()\n",
" }\n",
" states_input = {\n",
" name: tf.keras.Input(shape[1:], dtype=dtype, name=name)\n",
" for name, (shape, dtype) in state_shapes.items()\n",
" }\n",
"\n",
" # The inputs to the model are the states and the video\n",
" inputs = {**states_input, 'image': inputs}\n",
"\n",
" # Output shape: [batch_size, 600]\n",
" outputs = encoder(inputs)\n",
"\n",
" model = tf.keras.Model(inputs, outputs)\n",
" model.build([1, 1, 1, 1, 3])\n",
"\n",
" return model\n",
"\n",
"# Download example gif\n",
"!wget https://github.com/tensorflow/models/raw/f8af2291cced43fc9f1d9b41ddbf772ae7b0d7d2/official/projects/movinet/files/jumpingjack.gif -O jumpingjack.gif -q\n",
"\n",
"def load_gif(file_path, image_size=(224, 224)):\n",
" \"\"\"Loads a gif file into a TF tensor.\"\"\"\n",
" with tf.io.gfile.GFile(file_path, 'rb') as f:\n",
" video = tf.io.decode_gif(f.read())\n",
" video = tf.image.resize(video, image_size)\n",
" video = tf.cast(video, tf.float32) / 255.\n",
" return video\n",
"\n",
"def get_top_k_streaming_labels(probs, k=5, label_map=KINETICS_600_LABELS_LIST):\n",
" \"\"\"Returns the top-k labels over an entire video sequence.\n",
"\n",
" Args:\n",
" probs: probability tensor of shape (num_frames, num_classes) that represents\n",
" the probability of each class on each frame.\n",
" k: the number of top predictions to select.\n",
" label_map: a list of labels to map logit indices to label strings.\n",
"\n",
" Returns:\n",
" a tuple of the top-k probabilities, labels, and logit indices\n",
" \"\"\"\n",
" top_categories_last = tf.argsort(probs, -1, 'DESCENDING')[-1, :1]\n",
" categories = tf.argsort(probs, -1, 'DESCENDING')[:, :k]\n",
" categories = tf.reshape(categories, [-1])\n",
"\n",
" counts = sorted([\n",
" (i.numpy(), tf.reduce_sum(tf.cast(categories == i, tf.int32)).numpy())\n",
" for i in tf.unique(categories)[0]\n",
" ], key=lambda x: x[1], reverse=True)\n",
"\n",
" top_probs_idx = tf.constant([i for i, _ in counts[:k]])\n",
" top_probs_idx = tf.concat([top_categories_last, top_probs_idx], 0)\n",
" top_probs_idx = tf.unique(top_probs_idx)[0][:k+1]\n",
"\n",
" top_probs = tf.gather(probs, top_probs_idx, axis=-1)\n",
" top_probs = tf.transpose(top_probs, perm=(1, 0))\n",
" top_labels = tf.gather(label_map, top_probs_idx, axis=0)\n",
" top_labels = [label.decode('utf8') for label in top_labels.numpy()]\n",
"\n",
" return top_probs, top_labels, top_probs_idx\n",
"\n",
"def plot_streaming_top_preds_at_step(\n",
" top_probs,\n",
" top_labels,\n",
" step=None,\n",
" image=None,\n",
" legend_loc='lower left',\n",
" duration_seconds=10,\n",
" figure_height=500,\n",
" playhead_scale=0.8,\n",
" grid_alpha=0.3):\n",
" \"\"\"Generates a plot of the top video model predictions at a given time step.\n",
"\n",
" Args:\n",
" top_probs: a tensor of shape (k, num_frames) representing the top-k\n",
" probabilities over all frames.\n",
" top_labels: a list of length k that represents the top-k label strings.\n",
" step: the current time step in the range [0, num_frames].\n",
" image: the image frame to display at the current time step.\n",
" legend_loc: the placement location of the legend.\n",
" duration_seconds: the total duration of the video.\n",
" figure_height: the output figure height.\n",
" playhead_scale: scale value for the playhead.\n",
" grid_alpha: alpha value for the gridlines.\n",
"\n",
" Returns:\n",
" A tuple of the output numpy image, figure, and axes.\n",
" \"\"\"\n",
" num_labels, num_frames = top_probs.shape\n",
" if step is None:\n",
" step = num_frames\n",
"\n",
" fig = plt.figure(figsize=(6.5, 7), dpi=300)\n",
" gs = mpl.gridspec.GridSpec(8, 1)\n",
" ax2 = plt.subplot(gs[:-3, :])\n",
" ax = plt.subplot(gs[-3:, :])\n",
"\n",
" if image is not None:\n",
" ax2.imshow(image, interpolation='nearest')\n",
" ax2.axis('off')\n",
"\n",
" preview_line_x = tf.linspace(0., duration_seconds, num_frames)\n",
" preview_line_y = top_probs\n",
"\n",
" line_x = preview_line_x[:step+1]\n",
" line_y = preview_line_y[:, :step+1]\n",
"\n",
" for i in range(num_labels):\n",
" ax.plot(preview_line_x, preview_line_y[i], label=None, linewidth='1.5',\n",
" linestyle=':', color='gray')\n",
" ax.plot(line_x, line_y[i], label=top_labels[i], linewidth='2.0')\n",
"\n",
"\n",
" ax.grid(which='major', linestyle=':', linewidth='1.0', alpha=grid_alpha)\n",
" ax.grid(which='minor', linestyle=':', linewidth='0.5', alpha=grid_alpha)\n",
"\n",
" min_height = tf.reduce_min(top_probs) * playhead_scale\n",
" max_height = tf.reduce_max(top_probs)\n",
" ax.vlines(preview_line_x[step], min_height, max_height, colors='red')\n",
" ax.scatter(preview_line_x[step], max_height, color='red')\n",
"\n",
" ax.legend(loc=legend_loc)\n",
"\n",
" plt.xlim(0, duration_seconds)\n",
" plt.ylabel('Probability')\n",
" plt.xlabel('Time (s)')\n",
" plt.yscale('log')\n",
"\n",
" fig.tight_layout()\n",
" fig.canvas.draw()\n",
"\n",
" data = np.frombuffer(fig.canvas.tostring_rgb(), dtype=np.uint8)\n",
" data = data.reshape(fig.canvas.get_width_height()[::-1] + (3,))\n",
" plt.close()\n",
"\n",
" figure_width = int(figure_height * data.shape[1] / data.shape[0])\n",
" image = PIL.Image.fromarray(data).resize([figure_width, figure_height])\n",
" image = np.array(image)\n",
"\n",
" return image, (fig, ax, ax2)\n",
"\n",
"def plot_streaming_top_preds(\n",
" probs,\n",
" video,\n",
" top_k=5,\n",
" video_fps=25.,\n",
" figure_height=500,\n",
" use_progbar=True):\n",
" \"\"\"Generates a video plot of the top video model predictions.\n",
"\n",
" Args:\n",
" probs: probability tensor of shape (num_frames, num_classes) that represents\n",
" the probability of each class on each frame.\n",
" video: the video to display in the plot.\n",
" top_k: the number of top predictions to select.\n",
" video_fps: the input video fps.\n",
" figure_fps: the output video fps.\n",
" figure_height: the height of the output video.\n",
" use_progbar: display a progress bar.\n",
"\n",
" Returns:\n",
" A numpy array representing the output video.\n",
" \"\"\"\n",
" video_fps = 8.\n",
" figure_height = 500\n",
" steps = video.shape[0]\n",
" duration = steps / video_fps\n",
"\n",
" top_probs, top_labels, _ = get_top_k_streaming_labels(probs, k=top_k)\n",
"\n",
" images = []\n",
" step_generator = tqdm.trange(steps) if use_progbar else range(steps)\n",
" for i in step_generator:\n",
" image, _ = plot_streaming_top_preds_at_step(\n",
" top_probs=top_probs,\n",
" top_labels=top_labels,\n",
" step=i,\n",
" image=video[i],\n",
" duration_seconds=duration,\n",
" figure_height=figure_height,\n",
" )\n",
" images.append(image)\n",
"\n",
" return np.array(images)\n",
"\n",
"def generate_plot(\n",
" model,\n",
" video_url=None,\n",
" resolution=224,\n",
" video_fps=25,\n",
" display_fps=25):\n",
" # Load the video\n",
" if not video_url:\n",
" video_bytes = list(files.upload().values())[0]\n",
" with open('video', 'wb') as f:\n",
" f.write(video_bytes)\n",
" else:\n",
" urllib.request.urlretrieve(video_url, \"video\")\n",
"\n",
" video = tf.cast(media.read_video('video'), tf.float32) / 255.\n",
" video = tf.image.resize(video, [resolution, resolution], preserve_aspect_ratio=True)\n",
"\n",
" # Create initial states for the stream model\n",
" init_states_fn = model.layers[-1].resolved_object.signatures['init_states']\n",
" init_states = init_states_fn(tf.shape(video[tf.newaxis]))\n",
"\n",
" clips = tf.split(video[tf.newaxis], video.shape[0], axis=1)\n",
"\n",
" all_logits = []\n",
"\n",
" print('Running the model on the video...')\n",
"\n",
" # To run on a video, pass in one frame at a time\n",
" states = init_states\n",
" for clip in tqdm.tqdm(clips):\n",
" # Input shape: [1, 1, 172, 172, 3]\n",
" logits, states = model.predict({**states, 'image': clip}, verbose=0)\n",
" all_logits.append(logits)\n",
"\n",
" logits = tf.concat(all_logits, 0)\n",
" probs = tf.nn.softmax(logits)\n",
"\n",
" print('Generating the plot...')\n",
"\n",
" # Generate a plot and output to a video tensor\n",
" plot_video = plot_streaming_top_preds(probs, video, video_fps=video_fps)\n",
" media.show_video(plot_video, fps=display_fps, codec='gif')\n",
"\n",
"model_size = 'm' #@param [\"xs\", \"s\", \"m\", \"l\", \"xl\", \"xxl\"]\n",
"\n",
"model_map = {\n",
" 'xs': 'a0',\n",
" 's': 'a1',\n",
" 'm': 'a2',\n",
" 'l': 'a3',\n",
" 'xl': 'a4',\n",
" 'xxl': 'a5',\n",
"}\n",
"movinet_model_id = model_map[model_size]\n",
"\n",
"model = load_movinet_from_hub(\n",
" movinet_model_id, 'stream', hub_version=3)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"cellView": "form",
"id": "jO6HrPk8pqo8"
},
"outputs": [],
"source": [
"#@title Generate a video plot.\n",
"\n",
"#@markdown You may add a video URL (gif or mp4) or leave the video_url field blank to upload your own file.\n",
"video_url = \"https://i.pinimg.com/originals/33/5e/31/335e31bc8ed52511da0cfb4bc44e95c7.gif\" #@param {type:\"string\"}\n",
"\n",
"#@markdown The base input resolution to the model. A good value is 224, but can change based on model size.\n",
"resolution = 224 #@param\n",
"#@markdown The fps of the input video.\n",
"video_fps = 12 #@param\n",
"#@markdown The fps to display the output plot. Depending on the duration of the input video, it may help to use a lower fps.\n",
"display_fps = 12 #@param\n",
"\n",
"generate_plot(\n",
" model,\n",
" video_url=video_url,\n",
" resolution=resolution,\n",
" video_fps=video_fps,\n",
" display_fps=display_fps)"
]
}
],
"metadata": {
"colab": {
"collapsed_sections": [],
"last_runtime": {
"build_target": "//learning/deepmind/dm_python:dm_notebook3",
"kind": "private"
},
"name": "plot_movinet_video_stream_predictions.ipynb",
"provenance": []
},
"kernelspec": {
"display_name": "Python 3",
"name": "python3"
},
"language_info": {
"name": "python"
}
},
"nbformat": 4,
"nbformat_minor": 0
}