{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# Exporting Frozen Graphs in Tensorflow 2 \n", "In order to use a trained model as input to MIGraphX, the model must be first be saved in a frozen graph format. This was accomplished in Tensorflow 1 by launching a graph in a tf.Session and then saving the session. However, Tensorflow has decided to deprecate Sessions in favor of functions and SavedModel format. " ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "After importing the necessary libraries, the next step is to instantiate a model. For simplicity, in this example we will use a resnet50 architecture with pre-trained imagenet weights. These weights may also be trained or fine-tuned before freezing. " ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import tensorflow as tf\n", "tf.enable_eager_execution() #May not be required depending on tensorflow version\n", "from tensorflow.python.framework.convert_to_constants import convert_variables_to_constants_v2\n", "from tensorflow import keras\n", "from tensorflow.keras import layers\n", "\n", "MODEL_NAME = \"resnet50\"\n", "model = tf.keras.applications.ResNet50(weights=\"imagenet\")\n", "model.summary()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## SavedModel format\n", "The simplest way to save a model is through saved\\_model.save()\n", "\n", "This will create an equivalent tensorflow program which can later be loaded for fine-tuning or inference, although it is not directly compatible with MIGraphX." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "tf.saved_model.save(model, \"./Saved_Models/{}\".format(MODEL_NAME))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Convert to ConcreteFunction\n", "To begin, we need to get the function equivalent of the model and then concretize the function to avoid retracing." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "full_model = tf.function(lambda x: model(x))\n", "full_model = full_model.get_concrete_function(\n", " x=tf.TensorSpec(model.inputs[0].shape, model.inputs[0].dtype))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Freeze ConcreteFunction and Serialize\n", "Since we are saving the graph for the purpose of inference, all variables can be made constant (i.e. \"frozen\").\n", "\n", "Next, we need to obtain a serialized GraphDef representation of the graph. \n", "\n", "\n", "Optionally, the operators can be printed out layer by layer followed by the inputs and outputs." 
, { "cell_type": "markdown", "metadata": {}, "source": [ "## Convert to ConcreteFunction\n", "To begin, we need to get the function equivalent of the model and then concretize it, so the graph is traced once for a fixed input signature instead of being retraced on every call." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Wrap the model in a tf.function and trace it once with the model's input shape and dtype\n", "full_model = tf.function(lambda x: model(x))\n", "full_model = full_model.get_concrete_function(\n", "    x=tf.TensorSpec(model.inputs[0].shape, model.inputs[0].dtype))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Freeze ConcreteFunction and Serialize\n", "Since we are saving the graph for the purpose of inference, all variables can be made constant (i.e. \"frozen\").\n", "\n", "Next, we obtain a serialized GraphDef representation of the frozen graph.\n", "\n", "Optionally, the graph operations can be printed out one by one, followed by the graph inputs and outputs." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Convert all variables to constants (\"freeze\") the graph\n", "frozen_func = convert_variables_to_constants_v2(full_model)\n", "graph_def = frozen_func.graph.as_graph_def()  # tf.io.write_graph below accepts either the Graph or this GraphDef\n", "\n", "ops = [op.name for op in frozen_func.graph.get_operations()]\n", "print(\"-\" * 50)\n", "print(\"Frozen model operations:\")\n", "for op in ops:\n", "    print(op)\n", "\n", "print(\"-\" * 50)\n", "print(\"Frozen model inputs:\")\n", "print(frozen_func.inputs)\n", "print(\"Frozen model outputs:\")\n", "print(frozen_func.outputs)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Save Frozen Graph as Protobuf\n", "Finally, we can write the frozen graph to disk; it will be stored as `./frozen_models/{MODEL_NAME}_frozen_graph.pb` (here, `./frozen_models/resnet50_frozen_graph.pb`)." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# as_text=False writes a binary protobuf rather than a text GraphDef\n", "tf.io.write_graph(graph_or_graph_def=frozen_func.graph,\n", "                  logdir=\"./frozen_models\",\n", "                  name=\"{}_frozen_graph.pb\".format(MODEL_NAME),\n", "                  as_text=False)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Assuming MIGraphX has already been built and installed on your system, the driver can be used to verify that the frozen graph has been exported correctly." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import subprocess\n", "\n", "# \"read\" loads the model with MIGraphX and prints the parsed program\n", "driver = \"/opt/rocm/bin/migraphx-driver\"\n", "command = \"read\"\n", "model_path = \"./frozen_models/{}_frozen_graph.pb\".format(MODEL_NAME)\n", "process = subprocess.run([driver, command, model_path],\n", "                         stdout=subprocess.PIPE,\n", "                         universal_newlines=True)\n", "\n", "print(process.stdout)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "tensorflow", "language": "python", "name": "tensorflow" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.7.9" } }, "nbformat": 4, "nbformat_minor": 4 }