Commit 9adaa571 authored by Dan Kondratyuk's avatar Dan Kondratyuk Committed by A. Unique TensorFlower
Browse files

Update MoViNet Colab tutorial and fix errors.

PiperOrigin-RevId: 434474245
parent cdccf02c
...@@ -8,7 +8,20 @@ ...@@ -8,7 +8,20 @@
"source": [ "source": [
"# MoViNet Tutorial\n", "# MoViNet Tutorial\n",
"\n", "\n",
"This notebook provides basic example code to create, build, and run [MoViNets (Mobile Video Networks)](https://arxiv.org/pdf/2103.11511.pdf). Models use TF Keras and support inference in TF 1 and TF 2. Pretrained models are provided by [TensorFlow Hub](https://tfhub.dev/google/collections/movinet/), trained on [Kinetics 600](https://deepmind.com/research/open-source/kinetics) for video action classification." "This notebook provides basic example code to build, run, and fine-tune [MoViNets (Mobile Video Networks)](https://arxiv.org/pdf/2103.11511.pdf).\n",
"\n",
"Pretrained models are provided by [TensorFlow Hub](https://tfhub.dev/google/collections/movinet/) and the [TensorFlow Model Garden](https://github.com/tensorflow/models/tree/master/official/projects/movinet), trained on [Kinetics 600](https://deepmind.com/research/open-source/kinetics) for video action classification. All Models use TensorFlow 2 with Keras for inference and training.\n",
"\n",
"The following steps will be performed:\n",
"\n",
"1. [Running base model inference with TensorFlow Hub](#scrollTo=6g0tuFvf71S9\u0026line=8\u0026uniqifier=1)\n",
"2. [Running streaming model inference with TensorFlow Hub and plotting predictions](#scrollTo=ADrHPmwGcBZ5\u0026line=4\u0026uniqifier=1)\n",
"3. [Exporting a streaming model to TensorFlow Lite for mobile](#scrollTo=W3CLHvubvdSI\u0026line=3\u0026uniqifier=1)\n",
"4. [Fine-Tuning a base Model with the TensorFlow Model Garden](#scrollTo=_s-7bEoa3f8g\u0026line=11\u0026uniqifier=1)\n",
"\n",
"![jumping jacks plot](https://storage.googleapis.com/tf_model_garden/vision/movinet/artifacts/jumpingjacks_plot.gif)\n",
"\n",
"To generate video plots like the one above, see [section 2](#scrollTo=ADrHPmwGcBZ5\u0026line=4\u0026uniqifier=1)."
] ]
}, },
{ {
...@@ -19,17 +32,9 @@ ...@@ -19,17 +32,9 @@
"source": [ "source": [
"## Setup\n", "## Setup\n",
"\n", "\n",
"It is recommended to run the models using GPUs or TPUs.\n", "For inference on smaller models (A0-A2), CPU is sufficient for this Colab. For fine-tuning, it is recommended to run the models using GPUs.\n",
"\n",
"To select a GPU/TPU in Colab, select `Runtime \u003e Change runtime type \u003e Hardware accelerator` dropdown in the top menu.\n",
"\n",
"### Install the TensorFlow Model Garden pip package\n",
"\n",
"- tf-models-official is the stable Model Garden package. Note that it may not include the latest changes in the tensorflow_models github repo.\n",
"- To include latest changes, you may install tf-models-nightly, which is the nightly Model Garden package created daily automatically.\n",
"pip will install all models and dependencies automatically.\n",
"\n", "\n",
"Install the [mediapy](https://github.com/google/mediapy) package for visualizing images/videos." "To select a GPU in Colab, select `Runtime \u003e Change runtime type \u003e Hardware accelerator \u003e GPU` dropdown in the top menu."
] ]
}, },
{ {
...@@ -40,10 +45,24 @@ ...@@ -40,10 +45,24 @@
}, },
"outputs": [], "outputs": [],
"source": [ "source": [
"!pip install -q tf-models-nightly tfds-nightly\n", "# Install packages\n",
"\n", "\n",
"# tf-models-official is the stable Model Garden package\n",
"# tf-models-nightly includes latest changes\n",
"!pip install -q tf-models-nightly\n",
"\n",
"# Install tfds nightly to download ucf101\n",
"!pip install -q tfds-nightly\n",
"\n",
"# Install the mediapy package for visualizing images/videos.\n",
"# See https://github.com/google/mediapy\n",
"!command -v ffmpeg \u003e/dev/null || (apt update \u0026\u0026 apt install -y ffmpeg)\n", "!command -v ffmpeg \u003e/dev/null || (apt update \u0026\u0026 apt install -y ffmpeg)\n",
"!pip install -q mediapy" "!pip install -q mediapy\n",
"\n",
"# Due to a bug, we reinstall opencv\n",
"# See https://stackoverflow.com/q/70537488\n",
"!pip uninstall -q -y opencv-python-headless\n",
"!pip install -q \"opencv-python-headless\u003c4.3\""
] ]
}, },
{ {
...@@ -54,22 +73,267 @@ ...@@ -54,22 +73,267 @@
}, },
"outputs": [], "outputs": [],
"source": [ "source": [
"# Run imports\n",
"import os\n", "import os\n",
"from six.moves import urllib\n",
"\n", "\n",
"import matplotlib as mpl\n",
"import matplotlib.pyplot as plt\n", "import matplotlib.pyplot as plt\n",
"import mediapy as media\n", "import mediapy as media\n",
"import numpy as np\n", "import numpy as np\n",
"from PIL import Image\n", "import PIL\n",
"import pandas as pd\n",
"import tensorflow as tf\n", "import tensorflow as tf\n",
"import tensorflow_datasets as tfds\n", "import tensorflow_datasets as tfds\n",
"import tensorflow_hub as hub\n", "import tensorflow_hub as hub\n",
"import tqdm\n",
"\n", "\n",
"from official.vision.beta.configs import video_classification\n", "mpl.rcParams.update({\n",
"from official.projects.movinet.configs import movinet as movinet_configs\n", " 'font.size': 10,\n",
"from official.projects.movinet.modeling import movinet\n", "})"
"from official.projects.movinet.modeling import movinet_layers\n", ]
"from official.projects.movinet.modeling import movinet_model" },
{
"cell_type": "markdown",
"metadata": {
"id": "OnFqOXazoWgy"
},
"source": [
"Run the cell below to define helper functions and create variables."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "dx55NK3ZoZeh"
},
"outputs": [],
"source": [
"#@title Run this cell to set up some helper code.\n",
"\n",
"# Download Kinetics 600 label map\n",
"!wget https://raw.githubusercontent.com/tensorflow/models/f8af2291cced43fc9f1d9b41ddbf772ae7b0d7d2/official/projects/movinet/files/kinetics_600_labels.txt -O labels.txt -q\n",
"\n",
"with tf.io.gfile.GFile('labels.txt') as f:\n",
" lines = f.readlines()\n",
" KINETICS_600_LABELS_LIST = [line.strip() for line in lines]\n",
" KINETICS_600_LABELS = tf.constant(KINETICS_600_LABELS_LIST)\n",
"\n",
"def get_top_k(probs, k=5, label_map=KINETICS_600_LABELS):\n",
" \"\"\"Outputs the top k model labels and probabilities on the given video.\"\"\"\n",
" top_predictions = tf.argsort(probs, axis=-1, direction='DESCENDING')[:k]\n",
" top_labels = tf.gather(label_map, top_predictions, axis=-1)\n",
" top_labels = [label.decode('utf8') for label in top_labels.numpy()]\n",
" top_probs = tf.gather(probs, top_predictions, axis=-1).numpy()\n",
" return tuple(zip(top_labels, top_probs))\n",
"\n",
"def predict_top_k(model, video, k=5, label_map=KINETICS_600_LABELS):\n",
" \"\"\"Outputs the top k model labels and probabilities on the given video.\"\"\"\n",
" outputs = model.predict(video[tf.newaxis])[0]\n",
" probs = tf.nn.softmax(outputs)\n",
" return get_top_k(probs, k=k, label_map=label_map)\n",
"\n",
"def load_movinet_from_hub(model_id, model_mode, hub_version=3):\n",
" \"\"\"Loads a MoViNet model from TF Hub.\"\"\"\n",
" hub_url = f'https://tfhub.dev/tensorflow/movinet/{model_id}/{model_mode}/kinetics-600/classification/{hub_version}'\n",
"\n",
" encoder = hub.KerasLayer(hub_url, trainable=True)\n",
"\n",
" inputs = tf.keras.layers.Input(\n",
" shape=[None, None, None, 3],\n",
" dtype=tf.float32)\n",
"\n",
" if model_mode == 'base':\n",
" inputs = dict(image=inputs)\n",
" else:\n",
" # Define the state inputs, which is a dict that maps state names to tensors.\n",
" init_states_fn = encoder.resolved_object.signatures['init_states']\n",
" state_shapes = {\n",
" name: ([s if s \u003e 0 else None for s in state.shape], state.dtype)\n",
" for name, state in init_states_fn(tf.constant([0, 0, 0, 0, 3])).items()\n",
" }\n",
" states_input = {\n",
" name: tf.keras.Input(shape[1:], dtype=dtype, name=name)\n",
" for name, (shape, dtype) in state_shapes.items()\n",
" }\n",
"\n",
" # The inputs to the model are the states and the video\n",
" inputs = {**states_input, 'image': inputs}\n",
"\n",
" # Output shape: [batch_size, 600]\n",
" outputs = encoder(inputs)\n",
"\n",
" model = tf.keras.Model(inputs, outputs)\n",
" model.build([1, 1, 1, 1, 3])\n",
"\n",
" return model\n",
"\n",
"# Download example gif\n",
"!wget https://github.com/tensorflow/models/raw/f8af2291cced43fc9f1d9b41ddbf772ae7b0d7d2/official/projects/movinet/files/jumpingjack.gif -O jumpingjack.gif -q\n",
"\n",
"def load_gif(file_path, image_size=(224, 224)):\n",
" \"\"\"Loads a gif file into a TF tensor.\"\"\"\n",
" with tf.io.gfile.GFile(file_path, 'rb') as f:\n",
" video = tf.io.decode_gif(f.read())\n",
" video = tf.image.resize(video, image_size)\n",
" video = tf.cast(video, tf.float32) / 255.\n",
" return video\n",
"\n",
"def get_top_k_streaming_labels(probs, k=5, label_map=KINETICS_600_LABELS_LIST):\n",
" \"\"\"Returns the top-k labels over an entire video sequence.\n",
"\n",
" Args:\n",
" probs: probability tensor of shape (num_frames, num_classes) that represents\n",
" the probability of each class on each frame.\n",
" k: the number of top predictions to select.\n",
" label_map: a list of labels to map logit indices to label strings.\n",
"\n",
" Returns:\n",
" a tuple of the top-k probabilities, labels, and logit indices\n",
" \"\"\"\n",
" top_categories_last = tf.argsort(probs, -1, 'DESCENDING')[-1, :1]\n",
" categories = tf.argsort(probs, -1, 'DESCENDING')[:, :k]\n",
" categories = tf.reshape(categories, [-1])\n",
"\n",
" counts = sorted([\n",
" (i.numpy(), tf.reduce_sum(tf.cast(categories == i, tf.int32)).numpy())\n",
" for i in tf.unique(categories)[0]\n",
" ], key=lambda x: x[1], reverse=True)\n",
"\n",
" top_probs_idx = tf.constant([i for i, _ in counts[:k]])\n",
" top_probs_idx = tf.concat([top_categories_last, top_probs_idx], 0)\n",
" top_probs_idx = tf.unique(top_probs_idx)[0][:k+1]\n",
"\n",
" top_probs = tf.gather(probs, top_probs_idx, axis=-1)\n",
" top_probs = tf.transpose(top_probs, perm=(1, 0))\n",
" top_labels = tf.gather(label_map, top_probs_idx, axis=0)\n",
" top_labels = [label.decode('utf8') for label in top_labels.numpy()]\n",
"\n",
" return top_probs, top_labels, top_probs_idx\n",
"\n",
"def plot_streaming_top_preds_at_step(\n",
" top_probs,\n",
" top_labels,\n",
" step=None,\n",
" image=None,\n",
" legend_loc='lower left',\n",
" duration_seconds=10,\n",
" figure_height=500,\n",
" playhead_scale=0.8,\n",
" grid_alpha=0.3):\n",
" \"\"\"Generates a plot of the top video model predictions at a given time step.\n",
"\n",
" Args:\n",
" top_probs: a tensor of shape (k, num_frames) representing the top-k\n",
" probabilities over all frames.\n",
" top_labels: a list of length k that represents the top-k label strings.\n",
" step: the current time step in the range [0, num_frames].\n",
" image: the image frame to display at the current time step.\n",
" legend_loc: the placement location of the legend.\n",
" duration_seconds: the total duration of the video.\n",
" figure_height: the output figure height.\n",
" playhead_scale: scale value for the playhead.\n",
" grid_alpha: alpha value for the gridlines.\n",
"\n",
" Returns:\n",
" A tuple of the output numpy image, figure, and axes.\n",
" \"\"\"\n",
" num_labels, num_frames = top_probs.shape\n",
" if step is None:\n",
" step = num_frames\n",
"\n",
" fig = plt.figure(figsize=(6.5, 7), dpi=300)\n",
" gs = mpl.gridspec.GridSpec(8, 1)\n",
" ax2 = plt.subplot(gs[:-3, :])\n",
" ax = plt.subplot(gs[-3:, :])\n",
"\n",
" if image is not None:\n",
" ax2.imshow(image, interpolation='nearest')\n",
" ax2.axis('off')\n",
"\n",
" preview_line_x = tf.linspace(0., duration_seconds, num_frames)\n",
" preview_line_y = top_probs\n",
"\n",
" line_x = preview_line_x[:step+1]\n",
" line_y = preview_line_y[:, :step+1]\n",
"\n",
" for i in range(num_labels):\n",
" ax.plot(preview_line_x, preview_line_y[i], label=None, linewidth='1.5',\n",
" linestyle=':', color='gray')\n",
" ax.plot(line_x, line_y[i], label=top_labels[i], linewidth='2.0')\n",
"\n",
"\n",
" ax.grid(which='major', linestyle=':', linewidth='1.0', alpha=grid_alpha)\n",
" ax.grid(which='minor', linestyle=':', linewidth='0.5', alpha=grid_alpha)\n",
"\n",
" min_height = tf.reduce_min(top_probs) * playhead_scale\n",
" max_height = tf.reduce_max(top_probs)\n",
" ax.vlines(preview_line_x[step], min_height, max_height, colors='red')\n",
" ax.scatter(preview_line_x[step], max_height, color='red')\n",
"\n",
" ax.legend(loc=legend_loc)\n",
"\n",
" plt.xlim(0, duration_seconds)\n",
" plt.ylabel('Probability')\n",
" plt.xlabel('Time (s)')\n",
" plt.yscale('log')\n",
"\n",
" fig.tight_layout()\n",
" fig.canvas.draw()\n",
"\n",
" data = np.frombuffer(fig.canvas.tostring_rgb(), dtype=np.uint8)\n",
" data = data.reshape(fig.canvas.get_width_height()[::-1] + (3,))\n",
" plt.close()\n",
"\n",
" figure_width = int(figure_height * data.shape[1] / data.shape[0])\n",
" image = PIL.Image.fromarray(data).resize([figure_width, figure_height])\n",
" image = np.array(image)\n",
"\n",
" return image, (fig, ax, ax2)\n",
"\n",
"def plot_streaming_top_preds(\n",
" probs,\n",
" video,\n",
" top_k=5,\n",
" video_fps=25.,\n",
" figure_height=500,\n",
" use_progbar=True):\n",
" \"\"\"Generates a video plot of the top video model predictions.\n",
"\n",
" Args:\n",
" probs: probability tensor of shape (num_frames, num_classes) that represents\n",
" the probability of each class on each frame.\n",
" video: the video to display in the plot.\n",
" top_k: the number of top predictions to select.\n",
" video_fps: the input video fps.\n",
" figure_fps: the output video fps.\n",
" figure_height: the height of the output video.\n",
" use_progbar: display a progress bar.\n",
"\n",
" Returns:\n",
" A numpy array representing the output video.\n",
" \"\"\"\n",
" video_fps = 8.\n",
" figure_height = 500\n",
" steps = video.shape[0]\n",
" duration = steps / video_fps\n",
"\n",
" top_probs, top_labels, _ = get_top_k_streaming_labels(probs, k=top_k)\n",
"\n",
" images = []\n",
" step_generator = tqdm.trange(steps) if use_progbar else range(steps)\n",
" for i in step_generator:\n",
" image, _ = plot_streaming_top_preds_at_step(\n",
" top_probs=top_probs,\n",
" top_labels=top_labels,\n",
" step=i,\n",
" image=video[i],\n",
" duration_seconds=duration,\n",
" figure_height=figure_height,\n",
" )\n",
" images.append(image)\n",
"\n",
" return np.array(images)"
] ]
}, },
{ {
...@@ -78,95 +342,335 @@ ...@@ -78,95 +342,335 @@
"id": "6g0tuFvf71S9" "id": "6g0tuFvf71S9"
}, },
"source": [ "source": [
"## Example Usage with TensorFlow Hub\n", "## Running Base Model Inference with TensorFlow Hub\n",
"\n", "\n",
"Load MoViNet-A2-Base from TensorFlow Hub, as part of the [MoViNet collection](https://tfhub.dev/google/collections/movinet/).\n", "We will load MoViNet-A2-Base from TensorFlow Hub as part of the [MoViNet collection](https://tfhub.dev/google/collections/movinet/).\n",
"\n", "\n",
"The following code will:\n", "The following code will:\n",
"\n", "\n",
"- Load a MoViNet KerasLayer from [tfhub.dev](https://tfhub.dev).\n", "- Load a MoViNet KerasLayer from [tfhub.dev](https://tfhub.dev).\n",
"- Wrap the layer in a [Keras Model](https://www.tensorflow.org/api_docs/python/tf/keras/Model).\n", "- Wrap the layer in a [Keras Model](https://www.tensorflow.org/api_docs/python/tf/keras/Model).\n",
"- Load an example image, and reshape it to a single frame video.\n", "- Load an example gif as a video.\n",
"- Classify the video" "- Classify the video and print the top-5 predicted classes."
] ]
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": null, "execution_count": null,
"metadata": { "metadata": {
"id": "nTUdhlRJzl2o" "id": "KZKKNZVBpglJ"
}, },
"outputs": [], "outputs": [],
"source": [ "source": [
"movinet_a2_hub_url = 'https://tfhub.dev/tensorflow/movinet/a2/base/kinetics-600/classification/1'\n", "model = load_movinet_from_hub('a2', 'base', hub_version=3)"
"\n", ]
"inputs = tf.keras.layers.Input(\n", },
" shape=[None, None, None, 3],\n", {
" dtype=tf.float32)\n", "cell_type": "markdown",
"metadata": {
"id": "7kU1_pL10l0B"
},
"source": [
"To provide a simple example video for classification, we can load a short gif of jumping jacks being performed.\n",
"\n", "\n",
"encoder = hub.KerasLayer(movinet_a2_hub_url, trainable=True)\n", "![jumping jacks](https://github.com/tensorflow/models/raw/f8af2291cced43fc9f1d9b41ddbf772ae7b0d7d2/official/projects/movinet/files/jumpingjack.gif)\n",
"\n", "\n",
"# Important: To use tf.nn.conv3d on CPU, we must compile with tf.function.\n", "Attribution: Footage shared by [Coach Bobby Bluford](https://www.youtube.com/watch?v=-AxHpj-EuPg) on YouTube under the CC-BY license."
"encoder.call = tf.function(encoder.call, experimental_compile=True)\n", ]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "Iy0rKRrT723_"
},
"outputs": [],
"source": [
"video = load_gif('jumpingjack.gif', image_size=(172, 172))\n",
"\n", "\n",
"# [batch_size, 600]\n", "# Show video\n",
"outputs = encoder(dict(image=inputs))\n", "print(video.shape)\n",
"media.show_video(video.numpy(), fps=5)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "P0bZfrAsqPv2",
"outputId": "bd82571f-8dfd-4faf-ed10-e34708b0405d"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"jumping jacks 0.9166437\n",
"zumba 0.016020728\n",
"doing aerobics 0.008053946\n",
"dancing charleston 0.006083599\n",
"lunge 0.0035062772\n"
]
}
],
"source": [
"# Run the model on the video and output the top 5 predictions\n",
"outputs = predict_top_k(model, video)\n",
"\n", "\n",
"model = tf.keras.Model(inputs, outputs)" "for label, prob in outputs:\n",
" print(label, prob)"
] ]
}, },
{ {
"cell_type": "markdown", "cell_type": "markdown",
"metadata": { "metadata": {
"id": "7kU1_pL10l0B" "id": "ADrHPmwGcBZ5"
}, },
"source": [ "source": [
"To provide a simple example video for classification, we can load a static image and reshape it to produce a video with a single frame." "## Run Streaming Model Inference with TensorFlow Hub and Plot Predictions\n",
"\n",
"We will load MoViNet-A0-Stream from TensorFlow Hub as part of the [MoViNet collection](https://tfhub.dev/google/collections/movinet/).\n",
"\n",
"The following code will:\n",
"\n",
"- Load a MoViNet model from [tfhub.dev](https://tfhub.dev).\n",
"- Classify an example video and plot the streaming predictions over time."
] ]
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": null, "execution_count": null,
"metadata": { "metadata": {
"id": "Iy0rKRrT723_" "id": "tXWR13wthnK5"
}, },
"outputs": [], "outputs": [],
"source": [ "source": [
"image_url = 'https://upload.wikimedia.org/wikipedia/commons/8/84/Ski_Famille_-_Family_Ski_Holidays.jpg'\n", "model = load_movinet_from_hub('a2', 'stream', hub_version=3)\n",
"image_height = 224\n", "\n",
"image_width = 224\n", "# Create initial states for the stream model\n",
"init_states_fn = model.layers[-1].resolved_object.signatures['init_states']\n",
"init_states = init_states_fn(tf.shape(video[tf.newaxis]))"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "YqSkt7l8ltwt",
"outputId": "6ccf1dd6-95d1-43b1-efdb-2e931dd3a19d"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"100%|██████████| 13/13 [00:08\u003c00:00, 1.58it/s]\n",
"jumping jacks 0.9998123\n",
"zumba 0.00011835508\n",
"doing aerobics 3.3375818e-05\n",
"dancing charleston 4.9819987e-06\n",
"finger snapping 3.8673647e-06\n"
]
}
],
"source": [
"# Insert your video clip here\n",
"video = load_gif('jumpingjack.gif', image_size=(172, 172))\n",
"clips = tf.split(video[tf.newaxis], video.shape[0], axis=1)\n",
"\n",
"all_logits = []\n",
"\n", "\n",
"with urllib.request.urlopen(image_url) as f:\n", "# To run on a video, pass in one frame at a time\n",
" image = Image.open(f).resize((image_height, image_width))\n", "states = init_states\n",
"video = tf.reshape(np.array(image), [1, 1, image_height, image_width, 3])\n", "for clip in tqdm.tqdm(clips):\n",
"video = tf.cast(video, tf.float32) / 255.\n", " # Input shape: [1, 1, 172, 172, 3]\n",
" logits, states = model.predict({**states, 'image': clip}, verbose=0)\n",
" all_logits.append(logits)\n",
"\n", "\n",
"image" "logits = tf.concat(all_logits, 0)\n",
"probs = tf.nn.softmax(logits)\n",
"\n",
"final_probs = probs[-1]\n",
"top_k = get_top_k(final_probs)\n",
"print()\n",
"for label, prob in top_k:\n",
" print(label, prob)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "Xdox556CtMRb"
},
"outputs": [],
"source": [
"# Generate a plot and output to a video tensor\n",
"plot_video = plot_streaming_top_preds(probs, video, video_fps=8.)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "NSStKE9klCs3"
},
"outputs": [],
"source": [
"# For gif format, set codec='gif'\n",
"media.show_video(plot_video, fps=3)"
] ]
}, },
{ {
"cell_type": "markdown", "cell_type": "markdown",
"metadata": { "metadata": {
"id": "Yf6EefHuWfxC" "id": "W3CLHvubvdSI"
}, },
"source": [ "source": [
"Run the model and output the predicted label. Expected output should be skiing (labels 464-467). E.g., 465 = \"skiing crosscountry\".\n", "## Export a Streaming Model to TensorFlow Lite for Mobile\n",
"\n", "\n",
"See [here](https://gist.github.com/willprice/f19da185c9c5f32847134b87c1960769#file-kinetics_600_labels-csv) for a full list of all labels." "We will convert a MoViNet-A0-Stream model to [TensorFlow Lite](https://www.tensorflow.org/lite).\n",
"\n",
"The following code will:\n",
"- Load a MoViNet-A0-Stream model.\n",
"- Convert the model to TF Lite.\n",
"- Run inference on an example video using the Python interpreter."
] ]
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": null, "execution_count": null,
"metadata": { "metadata": {
"id": "OOpEKuqH8sH7" "id": "KH0j-07KVh06"
}, },
"outputs": [], "outputs": [],
"source": [ "source": [
"output = model(video)\n", "# Run imports\n",
"output_label_index = tf.argmax(output, -1)[0].numpy()\n", "from official.vision.configs import video_classification\n",
"from official.projects.movinet.configs import movinet as movinet_configs\n",
"from official.projects.movinet.modeling import movinet\n",
"from official.projects.movinet.modeling import movinet_layers\n",
"from official.projects.movinet.modeling import movinet_model\n",
"from official.projects.movinet.tools import export_saved_model"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "RLkV0xtPvfkY"
},
"outputs": [],
"source": [
"# Export to saved model\n",
"saved_model_dir = 'model'\n",
"tflite_filename = 'model.tflite'\n",
"input_shape = [1, 1, 172, 172, 3]\n",
"batch_size, num_frames, image_size, = input_shape[:3]\n",
"\n", "\n",
"print(output_label_index)" "tf.keras.backend.clear_session()\n",
"\n",
"# Create the model\n",
"input_specs = tf.keras.layers.InputSpec(shape=input_shape)\n",
"backbone = movinet.Movinet(\n",
" model_id='a0',\n",
" causal=True,\n",
" conv_type='2plus1d',\n",
" se_type='2plus3d',\n",
" input_specs=input_specs,\n",
" activation='hard_swish',\n",
" gating_activation='hard_sigmoid',\n",
" use_sync_bn=False,\n",
" use_external_states=True)\n",
"model = movinet_model.MovinetClassifier(\n",
" backbone=backbone,\n",
" activation='hard_swish',\n",
" num_classes=600,\n",
" output_states=True,\n",
" input_specs=dict(image=input_specs))\n",
"model.build([1, 1, 1, 1, 3])\n",
"\n",
"# Extract pretrained weights\n",
"!wget https://storage.googleapis.com/tf_model_garden/vision/movinet/movinet_a0_stream.tar.gz -O movinet_a0_stream.tar.gz -q\n",
"!tar -xvf movinet_a0_stream.tar.gz\n",
"\n",
"checkpoint_dir = 'movinet_a0_stream'\n",
"checkpoint_path = tf.train.latest_checkpoint(checkpoint_dir)\n",
"\n",
"# Convert to saved model\n",
"export_saved_model.export_saved_model(\n",
" model=model,\n",
" input_shape=input_shape,\n",
" export_path=saved_model_dir,\n",
" causal=True,\n",
" bundle_input_init_states_fn=False,\n",
" checkpoint_path=checkpoint_path)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "gPg_6eMC8IwF"
},
"outputs": [],
"source": [
"# Convert to TF Lite\n",
"converter = tf.lite.TFLiteConverter.from_saved_model(saved_model_dir)\n",
"tflite_model = converter.convert()\n",
"\n",
"with open(tflite_filename, 'wb') as f:\n",
" f.write(tflite_model)\n",
"\n",
"# Create the interpreter and signature runner\n",
"interpreter = tf.lite.Interpreter(model_path=tflite_filename)\n",
"runner = interpreter.get_signature_runner()\n",
"\n",
"init_states = {\n",
" name: tf.zeros(x['shape'], dtype=x['dtype'])\n",
" for name, x in runner.get_input_details().items()\n",
"}\n",
"del init_states['image']"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "-TQ-7oSJIlTA",
"outputId": "a15519ff-d08c-40bc-fbea-d3a58169450c"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"jumping jacks 0.9791285\n",
"jogging 0.0019550633\n",
"riding unicycle 0.0017429002\n",
"passing soccer ball 0.0016952101\n",
"stretching arm 0.0014458151\n"
]
}
],
"source": [
"# Insert your video clip here\n",
"video = load_gif('jumpingjack.gif', image_size=(172, 172))\n",
"clips = tf.split(video[tf.newaxis], video.shape[0], axis=1)\n",
"\n",
"# To run on a video, pass in one frame at a time\n",
"states = init_states\n",
"for clip in clips:\n",
" # Input shape: [1, 1, 172, 172, 3]\n",
" outputs = runner(**states, image=clip)\n",
" logits = outputs.pop('logits')[0]\n",
" states = outputs\n",
"\n",
"probs = tf.nn.softmax(logits)\n",
"top_k = get_top_k(probs)\n",
"print()\n",
"for label, prob in top_k:\n",
" print(label, prob)"
] ]
}, },
{ {
...@@ -175,17 +679,17 @@ ...@@ -175,17 +679,17 @@
"id": "_s-7bEoa3f8g" "id": "_s-7bEoa3f8g"
}, },
"source": [ "source": [
"## Example Usage with the TensorFlow Model Garden\n", "## Fine-Tune a Base Model with the TensorFlow Model Garden\n",
"\n", "\n",
"Fine-tune MoViNet-A0-Base on [UCF-101](https://www.crcv.ucf.edu/research/data-sets/ucf101/).\n", "We will Fine-tune MoViNet-A0-Base on [UCF-101](https://www.crcv.ucf.edu/research/data-sets/ucf101/).\n",
"\n", "\n",
"The following code will:\n", "The following code will:\n",
"\n", "\n",
"- Load the UCF-101 dataset with [TensorFlow Datasets](https://www.tensorflow.org/datasets/catalog/ucf101).\n", "- Load the UCF-101 dataset with [TensorFlow Datasets](https://www.tensorflow.org/datasets/catalog/ucf101).\n",
"- Create a [`tf.data.Dataset`](https://www.tensorflow.org/api_docs/python/tf/data/Dataset) pipeline for training and evaluation.\n", "- Create a simple [`tf.data.Dataset`](https://www.tensorflow.org/api_docs/python/tf/data/Dataset) pipeline for training and evaluation.\n",
"- Display some example videos from the dataset.\n", "- Display some example videos from the dataset.\n",
"- Build a MoViNet model and load pretrained weights.\n", "- Build a MoViNet model and load pretrained weights.\n",
"- Fine-tune the final classifier layers on UCF-101." "- Fine-tune the final classifier layers on UCF-101 and evaluate accuracy on the validation set."
] ]
}, },
{ {
...@@ -196,7 +700,25 @@ ...@@ -196,7 +700,25 @@
"source": [ "source": [
"### Load the UCF-101 Dataset with TensorFlow Datasets\n", "### Load the UCF-101 Dataset with TensorFlow Datasets\n",
"\n", "\n",
"Calling `download_and_prepare()` will automatically download the dataset. After downloading, this cell will output information about the dataset." "Calling `download_and_prepare()` will automatically download the dataset. This step may take up to 1 hour depending on the download and extraction speed. After downloading, the next cell will output information about the dataset."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "2IHLbPAfrs5P"
},
"outputs": [],
"source": [
"# Run imports\n",
"import tensorflow_datasets as tfds\n",
"\n",
"from official.vision.configs import video_classification\n",
"from official.projects.movinet.configs import movinet as movinet_configs\n",
"from official.projects.movinet.modeling import movinet\n",
"from official.projects.movinet.modeling import movinet_layers\n",
"from official.projects.movinet.modeling import movinet_model"
] ]
}, },
{ {
...@@ -288,7 +810,7 @@ ...@@ -288,7 +810,7 @@
")" ")"
] ]
}, },
"execution_count": 0, "execution_count": null,
"metadata": { "metadata": {
"tags": [] "tags": []
}, },
...@@ -310,15 +832,6 @@ ...@@ -310,15 +832,6 @@
"builder.info" "builder.info"
] ]
}, },
{
"cell_type": "markdown",
"metadata": {
"id": "BsJJgnBBqDKZ"
},
"source": [
"Build the training and evaluation datasets."
]
},
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": null, "execution_count": null,
...@@ -327,6 +840,8 @@ ...@@ -327,6 +840,8 @@
}, },
"outputs": [], "outputs": [],
"source": [ "source": [
"# Build the training and evaluation datasets.\n",
"\n",
"batch_size = 8\n", "batch_size = 8\n",
"num_frames = 8\n", "num_frames = 8\n",
"frame_stride = 10\n", "frame_stride = 10\n",
...@@ -392,16 +907,9 @@ ...@@ -392,16 +907,9 @@
"id": "R3RHeuHdsd_3" "id": "R3RHeuHdsd_3"
}, },
"source": [ "source": [
"### Build MoViNet-A0-Base and Load Pretrained Weights" "### Build MoViNet-A0-Base and Load Pretrained Weights\n",
] "\n",
}, "Here we create a MoViNet model using the open source code provided in [official/projects/movinet](https://github.com/tensorflow/models/tree/master/official/projects/movinet) and load the pretrained weights. Here we freeze the all layers except the final classifier head to speed up fine-tuning."
{
"cell_type": "markdown",
"metadata": {
"id": "JXVQOP9Rqk0I"
},
"source": [
"Here we create a MoViNet model using the open source code provided in [tensorflow/models](https://github.com/tensorflow/models) and load the pretrained weights. Here we freeze the all layers except the final classifier head to speed up fine-tuning."
] ]
}, },
{ {
...@@ -416,32 +924,38 @@ ...@@ -416,32 +924,38 @@
"\n", "\n",
"tf.keras.backend.clear_session()\n", "tf.keras.backend.clear_session()\n",
"\n", "\n",
"backbone = movinet.Movinet(\n", "backbone = movinet.Movinet(model_id=model_id)\n",
" model_id=model_id)\n", "model = movinet_model.MovinetClassifier(backbone=backbone, num_classes=600)\n",
"model = movinet_model.MovinetClassifier(\n", "model.build([1, 1, 1, 1, 3])\n",
" backbone=backbone,\n",
" num_classes=600)\n",
"model.build([batch_size, num_frames, resolution, resolution, 3])\n",
"\n", "\n",
"# Load pretrained weights from TF Hub\n", "# Load pretrained weights\n",
"movinet_hub_url = f'https://tfhub.dev/tensorflow/movinet/{model_id}/base/kinetics-600/classification/1'\n", "!wget https://storage.googleapis.com/tf_model_garden/vision/movinet/movinet_a0_base.tar.gz -O movinet_a0_base.tar.gz -q\n",
"movinet_hub_model = hub.KerasLayer(movinet_hub_url, trainable=True)\n", "!tar -xvf movinet_a0_base.tar.gz\n",
"pretrained_weights = {w.name: w for w in movinet_hub_model.weights}\n",
"model_weights = {w.name: w for w in model.weights}\n",
"for name in pretrained_weights:\n",
" model_weights[name].assign(pretrained_weights[name])\n",
"\n", "\n",
"# Wrap the backbone with a new classifier to create a new classifier head\n", "checkpoint_dir = 'movinet_a0_base'\n",
"# with num_classes outputs\n", "checkpoint_path = tf.train.latest_checkpoint(checkpoint_dir)\n",
"model = movinet_model.MovinetClassifier(\n", "checkpoint = tf.train.Checkpoint(model=model)\n",
" backbone=backbone,\n", "status = checkpoint.restore(checkpoint_path)\n",
" num_classes=num_classes)\n", "status.assert_existing_objects_matched()\n",
"model.build([batch_size, num_frames, resolution, resolution, 3])\n", "\n",
"def build_classifier(backbone, num_classes, freeze_backbone=False):\n",
" \"\"\"Builds a classifier on top of a backbone model.\"\"\"\n",
" model = movinet_model.MovinetClassifier(\n",
" backbone=backbone,\n",
" num_classes=num_classes)\n",
" model.build([batch_size, num_frames, resolution, resolution, 3])\n",
"\n", "\n",
"# Freeze all layers except for the final classifier head\n", " if freeze_backbone:\n",
"for layer in model.layers[:-1]:\n", " for layer in model.layers[:-1]:\n",
" layer.trainable = False\n", " layer.trainable = False\n",
"model.layers[-1].trainable = True" " model.layers[-1].trainable = True\n",
"\n",
" return model\n",
"\n",
"# Wrap the backbone with a new classifier to create a new classifier head\n",
"# with num_classes outputs (101 classes for UCF101).\n",
"# Freeze all layers except for the final classifier head.\n",
"model = build_classifier(backbone, num_classes, freeze_backbone=True)"
] ]
}, },
{ {
...@@ -500,7 +1014,7 @@ ...@@ -500,7 +1014,7 @@
"id": "0IyAOOlcpHna" "id": "0IyAOOlcpHna"
}, },
"source": [ "source": [
"Run the fine-tuning with Keras compile/fit. After fine-tuning the model, we should be able to achieve \u003e70% accuracy on the test set." "Run the fine-tuning with Keras compile/fit. After fine-tuning the model, we should be able to achieve \u003e85% accuracy on the test set."
] ]
}, },
{ {
...@@ -527,11 +1041,11 @@ ...@@ -527,11 +1041,11 @@
"output_type": "stream", "output_type": "stream",
"text": [ "text": [
"Epoch 1/3\n", "Epoch 1/3\n",
"1192/1192 [==============================] - 348s 286ms/step - loss: 3.4914 - top_1: 0.3639 - top_5: 0.6294 - val_loss: 2.5153 - val_top_1: 0.5975 - val_top_5: 0.8565\n", "1192/1192 [==============================] - 551s 451ms/step - loss: 2.5050 - top_1: 0.6692 - top_5: 0.8753 - val_loss: 1.6310 - val_top_1: 0.8109 - val_top_5: 0.9701\n",
"Epoch 2/3\n", "Epoch 2/3\n",
"1192/1192 [==============================] - 286s 240ms/step - loss: 2.1397 - top_1: 0.6794 - top_5: 0.9231 - val_loss: 2.0695 - val_top_1: 0.6838 - val_top_5: 0.9070\n", "1192/1192 [==============================] - 533s 447ms/step - loss: 1.3336 - top_1: 0.9024 - top_5: 0.9906 - val_loss: 1.4576 - val_top_1: 0.8451 - val_top_5: 0.9740\n",
"Epoch 3/3\n", "Epoch 3/3\n",
"1192/1192 [==============================] - 348s 292ms/step - loss: 1.8925 - top_1: 0.7660 - top_5: 0.9454 - val_loss: 1.9848 - val_top_1: 0.7116 - val_top_5: 0.9227\n" "1192/1192 [==============================] - 531s 446ms/step - loss: 1.2298 - top_1: 0.9329 - top_5: 0.9943 - val_loss: 1.4351 - val_top_1: 0.8514 - val_top_5: 0.9762\n"
] ]
} }
], ],
...@@ -573,7 +1087,7 @@ ...@@ -573,7 +1087,7 @@
"colab": { "colab": {
"collapsed_sections": [], "collapsed_sections": [],
"last_runtime": { "last_runtime": {
"build_target": "//learning/deepmind/public/tools/ml_python:ml_notebook", "build_target": "//learning/deepmind/dm_python:dm_notebook3",
"kind": "private" "kind": "private"
}, },
"name": "movinet_tutorial.ipynb", "name": "movinet_tutorial.ipynb",
......
# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
...@@ -51,6 +51,8 @@ python3 export_saved_model.py \ ...@@ -51,6 +51,8 @@ python3 export_saved_model.py \
To use an exported saved_model, refer to export_saved_model_test.py. To use an exported saved_model, refer to export_saved_model_test.py.
""" """
from typing import Optional, Tuple
from absl import app from absl import app
from absl import flags from absl import flags
import tensorflow as tf import tensorflow as tf
...@@ -113,62 +115,50 @@ flags.DEFINE_string( ...@@ -113,62 +115,50 @@ flags.DEFINE_string(
FLAGS = flags.FLAGS FLAGS = flags.FLAGS
def main(_) -> None: def export_saved_model(
input_specs = tf.keras.layers.InputSpec(shape=[ model: tf.keras.Model,
FLAGS.batch_size, input_shape: Tuple[int, int, int, int, int],
FLAGS.num_frames, export_path: str = '/tmp/movinet/',
FLAGS.image_size, causal: bool = False,
FLAGS.image_size, bundle_input_init_states_fn: bool = True,
3, checkpoint_path: Optional[str] = None) -> None:
]) """Exports a MoViNet model to a saved model.
Args:
model: the tf.keras.Model to export.
input_shape: The 5D spatiotemporal input shape of size
[batch_size, num_frames, image_height, image_width, num_channels].
Set the field or a shape position in the field to None for dynamic input.
export_path: Export path to save the saved_model file.
causal: Run the model in causal mode.
bundle_input_init_states_fn: Add init_states as a function signature to the
saved model. This is not necessary if the input shape is static (e.g.,
for TF Lite).
checkpoint_path: Checkpoint path to load. Leave blank to keep the model's
initialization.
"""
# Use dimensions of 1 except the channels to export faster, # Use dimensions of 1 except the channels to export faster,
# since we only really need the last dimension to build and get the output # since we only really need the last dimension to build and get the output
# states. These dimensions can be set to `None` once the model is built. # states. These dimensions can be set to `None` once the model is built.
input_shape = [1 if s is None else s for s in input_specs.shape] input_shape_concrete = [1 if s is None else s for s in input_shape]
model.build(input_shape_concrete)
# Override swish activation implementation to remove custom gradients
activation = FLAGS.activation
if activation == 'swish':
activation = 'simple_swish'
classifier_activation = FLAGS.classifier_activation
if classifier_activation == 'swish':
classifier_activation = 'simple_swish'
backbone = movinet.Movinet(
model_id=FLAGS.model_id,
causal=FLAGS.causal,
use_positional_encoding=FLAGS.use_positional_encoding,
conv_type=FLAGS.conv_type,
se_type=FLAGS.se_type,
input_specs=input_specs,
activation=activation,
gating_activation=FLAGS.gating_activation,
use_sync_bn=False,
use_external_states=FLAGS.causal)
model = movinet_model.MovinetClassifier(
backbone,
num_classes=FLAGS.num_classes,
output_states=FLAGS.causal,
input_specs=dict(image=input_specs),
activation=classifier_activation)
model.build(input_shape)
# Compile model to generate some internal Keras variables. # Compile model to generate some internal Keras variables.
model.compile() model.compile()
if FLAGS.checkpoint_path: if checkpoint_path:
checkpoint = tf.train.Checkpoint(model=model) checkpoint = tf.train.Checkpoint(model=model)
status = checkpoint.restore(FLAGS.checkpoint_path) status = checkpoint.restore(checkpoint_path)
status.assert_existing_objects_matched() status.assert_existing_objects_matched()
if FLAGS.causal: if causal:
# Call the model once to get the output states. Call again with `states` # Call the model once to get the output states. Call again with `states`
# input to ensure that the inputs with the `states` argument is built # input to ensure that the inputs with the `states` argument is built
# with the full output state shapes. # with the full output state shapes.
input_image = tf.ones(input_shape) input_image = tf.ones(input_shape_concrete)
_, states = model({**model.init_states(input_shape), 'image': input_image}) _, states = model({
**model.init_states(input_shape_concrete), 'image': input_image})
_ = model({**states, 'image': input_image}) _ = model({**states, 'image': input_image})
# Create a function to explicitly set the names of the outputs # Create a function to explicitly set the names of the outputs
...@@ -179,10 +169,10 @@ def main(_) -> None: ...@@ -179,10 +169,10 @@ def main(_) -> None:
specs = { specs = {
name: tf.TensorSpec(spec.shape, name=name, dtype=spec.dtype) name: tf.TensorSpec(spec.shape, name=name, dtype=spec.dtype)
for name, spec in model.initial_state_specs( for name, spec in model.initial_state_specs(
input_specs.shape).items() input_shape).items()
} }
specs['image'] = tf.TensorSpec( specs['image'] = tf.TensorSpec(
input_specs.shape, dtype=model.dtype, name='image') input_shape, dtype=model.dtype, name='image')
predict_fn = tf.function(predict, jit_compile=True) predict_fn = tf.function(predict, jit_compile=True)
predict_fn = predict_fn.get_concrete_function(specs) predict_fn = predict_fn.get_concrete_function(specs)
...@@ -191,17 +181,118 @@ def main(_) -> None: ...@@ -191,17 +181,118 @@ def main(_) -> None:
init_states_fn = init_states_fn.get_concrete_function( init_states_fn = init_states_fn.get_concrete_function(
tf.TensorSpec([5], dtype=tf.int32)) tf.TensorSpec([5], dtype=tf.int32))
if FLAGS.bundle_input_init_states_fn: if bundle_input_init_states_fn:
signatures = {'call': predict_fn, 'init_states': init_states_fn} signatures = {'call': predict_fn, 'init_states': init_states_fn}
else: else:
signatures = predict_fn signatures = predict_fn
tf.keras.models.save_model( tf.keras.models.save_model(
model, FLAGS.export_path, signatures=signatures) model, export_path, signatures=signatures)
else: else:
_ = model(tf.ones(input_shape)) _ = model(tf.ones(input_shape_concrete))
tf.keras.models.save_model(model, FLAGS.export_path) tf.keras.models.save_model(model, export_path)
def build_and_export_saved_model(
export_path: str = '/tmp/movinet/',
model_id: str = 'a0',
causal: bool = False,
conv_type: str = '3d',
se_type: str = '3d',
activation: str = 'swish',
classifier_activation: str = 'swish',
gating_activation: str = 'sigmoid',
use_positional_encoding: bool = False,
num_classes: int = 600,
input_shape: Optional[Tuple[int, int, int, int, int]] = None,
bundle_input_init_states_fn: bool = True,
checkpoint_path: Optional[str] = None) -> None:
"""Builds and exports a MoViNet model to a saved model.
Args:
export_path: Export path to save the saved_model file.
model_id: MoViNet model name.
causal: Run the model in causal mode.
conv_type: 3d, 2plus1d, or 3d_2plus1d. 3d configures the network
to use the default 3D convolution. 2plus1d uses (2+1)D convolution
with Conv2D operations and 2D reshaping (e.g., a 5x3x3 kernel becomes
3x3 followed by 5x1 conv). 3d_2plus1d uses (2+1)D convolution with
Conv3D and no 2D reshaping (e.g., a 5x3x3 kernel becomes 1x3x3
followed by 5x1x1 conv).
se_type:
3d, 2d, or 2plus3d. 3d uses the default 3D spatiotemporal global average
pooling for squeeze excitation. 2d uses 2D spatial global average pooling
on each frame. 2plus3d concatenates both 3D and 2D global average
pooling.
activation: The main activation to use across layers.
classifier_activation: The classifier activation to use.
gating_activation: The gating activation to use in squeeze-excitation
layers.
use_positional_encoding: Whether to use positional encoding (only applied
when causal=True).
num_classes: The number of classes for prediction.
input_shape: The 5D spatiotemporal input shape of size
[batch_size, num_frames, image_height, image_width, num_channels].
Set the field or a shape position in the field to None for dynamic input.
bundle_input_init_states_fn: Add init_states as a function signature to the
saved model. This is not necessary if the input shape is static (e.g.,
for TF Lite).
checkpoint_path: Checkpoint path to load. Leave blank for default
initialization.
"""
input_specs = tf.keras.layers.InputSpec(shape=input_shape)
# Override swish activation implementation to remove custom gradients
if activation == 'swish':
activation = 'simple_swish'
if classifier_activation == 'swish':
classifier_activation = 'simple_swish'
backbone = movinet.Movinet(
model_id=model_id,
causal=causal,
use_positional_encoding=use_positional_encoding,
conv_type=conv_type,
se_type=se_type,
input_specs=input_specs,
activation=activation,
gating_activation=gating_activation,
use_sync_bn=False,
use_external_states=causal)
model = movinet_model.MovinetClassifier(
backbone,
num_classes=num_classes,
output_states=causal,
input_specs=dict(image=input_specs),
activation=classifier_activation)
export_saved_model(
model=model,
input_shape=input_shape,
export_path=export_path,
causal=causal,
bundle_input_init_states_fn=bundle_input_init_states_fn,
checkpoint_path=checkpoint_path)
def main(_) -> None:
input_shape = (
FLAGS.batch_size, FLAGS.num_frames, FLAGS.image_size, FLAGS.image_size, 3)
build_and_export_saved_model(
export_path=FLAGS.export_path,
model_id=FLAGS.model_id,
causal=FLAGS.causal,
conv_type=FLAGS.conv_type,
se_type=FLAGS.se_type,
activation=FLAGS.activation,
classifier_activation=FLAGS.classifier_activation,
gating_activation=FLAGS.gating_activation,
use_positional_encoding=FLAGS.use_positional_encoding,
num_classes=FLAGS.num_classes,
input_shape=input_shape,
bundle_input_init_states_fn=FLAGS.bundle_input_init_states_fn,
checkpoint_path=FLAGS.checkpoint_path)
print(' ----- Done. Saved Model is saved at {}'.format(FLAGS.export_path)) print(' ----- Done. Saved Model is saved at {}'.format(FLAGS.export_path))
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment