Commit 743c28aa authored by Scott Main's avatar Scott Main Committed by TF Object Detection Team
Browse files

Generalize the xml parsing to facilitate custom datasets.

Instead of searching the annotation files based on class names in the filenames, read all the XML annotation files. For the cats/dogs use-case, just delete all the XML files that are not the breeds we care about.

PiperOrigin-RevId: 410613750
parent 3c22c16d
{
"nbformat": 4,
"nbformat_minor": 0,
"metadata": {
"colab": {
"name": "Generate_SSD_anchor_box_aspect_ratios_using_k_means_clustering.ipynb",
"provenance": [],
"collapsed_sections": []
},
"kernelspec": {
"name": "python3",
"display_name": "Python 3"
},
"language_info": {
"name": "python"
}
},
"cells": [
{
"cell_type": "markdown",
......@@ -55,20 +39,22 @@
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "hCQlBGJkZTR2"
},
"outputs": [],
"source": [
"import tensorflow as tf"
],
"execution_count": null,
"outputs": []
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "aw-Ba-5RUhMs"
},
"outputs": [],
"source": [
"# Install the tensorflow Object Detection API...\n",
"# If you're running this offline, you also might need to install the protobuf-compiler:\n",
......@@ -87,9 +73,7 @@
"\n",
"# Test the installation\n",
"! python object_detection/builders/model_builder_tf2_test.py"
],
"execution_count": null,
"outputs": []
]
},
{
"cell_type": "markdown",
......@@ -113,19 +97,21 @@
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "sKYfhq7CKZ4B"
},
"outputs": [],
"source": [
"%mkdir /content/dataset\n",
"%cd /content/dataset\n",
"! wget http://www.robots.ox.ac.uk/~vgg/data/pets/data/images.tar.gz\n",
"! wget http://www.robots.ox.ac.uk/~vgg/data/pets/data/annotations.tar.gz\n",
"! tar zxf images.tar.gz\n",
"! tar zxf annotations.tar.gz"
],
"execution_count": null,
"outputs": []
"! tar zxf annotations.tar.gz\n",
"\n",
"XML_PATH = '/content/dataset/annotations/xmls'"
]
},
{
"cell_type": "markdown",
......@@ -133,28 +119,53 @@
"id": "44vtL0nsAqXg"
},
"source": [
"In this case, we want to reduce the PETS dataset to match the collection of cats and dogs used to train the model (in [this training notebook](https://colab.sandbox.google.com/github/google-coral/tutorials/blob/master/retrain_ssdlite_mobiledet_qat_tf1.ipynb)):\n",
"Because the following k-means script will process all XML annotations, we want to reduce the PETS dataset to include only the cats and dogs used to train the model (in [this training notebook](https://colab.sandbox.google.com/github/google-coral/tutorials/blob/master/retrain_ssdlite_mobiledet_qat_tf1.ipynb)). So we delete all annotation files that are **not** Abyssinian or American bulldog:\n",
"\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "8gcUoBU2K_s7"
"id": "ih48zFbl6jM7"
},
"outputs": [],
"source": [
"! cp /content/dataset/annotations/list.txt /content/dataset/annotations/list_petsdataset.txt\n",
"! cp /content/dataset/annotations/trainval.txt /content/dataset/annotations/trainval_petsdataset.txt\n",
"! cp /content/dataset/annotations/test.txt /content/dataset/annotations/test_petsdataset.txt\n",
"! grep \"Abyssinian\" /content/dataset/annotations/list_petsdataset.txt > /content/dataset/annotations/list.txt\n",
"! grep \"american_bulldog\" /content/dataset/annotations/list_petsdataset.txt >> /content/dataset/annotations/list.txt\n",
"! grep \"Abyssinian\" /content/dataset/annotations/trainval_petsdataset.txt > /content/dataset/annotations/trainval.txt\n",
"! grep \"american_bulldog\" /content/dataset/annotations/trainval_petsdataset.txt >> /content/dataset/annotations/trainval.txt\n",
"! grep \"Abyssinian\" /content/dataset/annotations/test_petsdataset.txt > /content/dataset/annotations/test.txt\n",
"! grep \"american_bulldog\" /content/dataset/annotations/test_petsdataset.txt >> /content/dataset/annotations/test.txt"
],
"! (cd /content/dataset/annotations/xmls/ \u0026\u0026 \\\n",
" find . ! \\( -name 'Abyssinian*' -o -name 'american_bulldog*' \\) -type f -exec rm -f {} \\; )"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "KG8uraCK-RSM"
},
"source": [
"### Upload your own dataset"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "m0bh_iKD-Xz4"
},
"source": [
"To generate the anchor box ratios for your own dataset, upload a ZIP file with your annotation files (click the **Files** tab on the left, and drag-drop your ZIP file there), and then uncomment the following code to unzip it and specify the path to the directory with your annotation files:"
]
},
{
"cell_type": "code",
"execution_count": null,
"outputs": []
"metadata": {
"id": "M0j_vWDR3WkK"
},
"outputs": [],
"source": [
"# %cd /content/\n",
"# !unzip dataset.zip\n",
"\n",
"# XML_PATH = '/content/dataset/annotations/xmls'"
]
},
{
"cell_type": "markdown",
......@@ -188,23 +199,24 @@
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "vCB8Dfs0Xlyv"
},
"outputs": [],
"source": [
"import sys\n",
"import glob\n",
"import os\n",
"import numpy as np\n",
"import xml.etree.ElementTree as ET\n",
"\n",
"from sklearn.cluster import KMeans\n",
"\n",
"def xml_to_boxes(path, classes, rescale_width=None, rescale_height=None):\n",
"def xml_to_boxes(path, rescale_width=None, rescale_height=None):\n",
" \"\"\"Extracts bounding-box widths and heights from ground-truth dataset.\n",
"\n",
" Args:\n",
" path : Path to .xml annotation files for your dataset.\n",
" classes : List of classes that are part of dataset.\n",
" rescale_width : Scaling factor to rescale width of bounding box.\n",
" rescale_height : Scaling factor to rescale height of bounding box.\n",
"\n",
......@@ -213,23 +225,20 @@
" \"\"\"\n",
"\n",
" xml_list = []\n",
" for clss in classes:\n",
" for xml_file in glob.glob(path + '/'+clss+'*'):\n",
" if xml_file.endswith('.xml'):\n",
" tree = ET.parse(xml_file)\n",
" root = tree.getroot()\n",
" for member in root.findall('object'):\n",
" bndbox = member.find('bndbox')\n",
" bbox_width = int(bndbox.find('xmax').text) - int(bndbox.find('xmin').text)\n",
" bbox_height = int(bndbox.find('ymax').text) - int(bndbox.find('ymin').text)\n",
" if rescale_width and rescale_height:\n",
" size = root.find('size')\n",
" bbox_width = bbox_width * (rescale_width / int(size.find('width').text))\n",
" bbox_height = bbox_height * (rescale_height / int(size.find('height').text))\n",
"\n",
" xml_list.append([bbox_width, bbox_height])\n",
" else:\n",
" continue\n",
" filenames = os.listdir(os.path.join(path))\n",
" filenames = [os.path.join(path, f) for f in filenames if (f.endswith('.xml'))]\n",
" for xml_file in filenames:\n",
" tree = ET.parse(xml_file)\n",
" root = tree.getroot()\n",
" for member in root.findall('object'):\n",
" bndbox = member.find('bndbox')\n",
" bbox_width = int(bndbox.find('xmax').text) - int(bndbox.find('xmin').text)\n",
" bbox_height = int(bndbox.find('ymax').text) - int(bndbox.find('ymin').text)\n",
" if rescale_width and rescale_height:\n",
" size = root.find('size')\n",
" bbox_width = bbox_width * (rescale_width / int(size.find('width').text))\n",
" bbox_height = bbox_height * (rescale_height / int(size.find('height').text))\n",
" xml_list.append([bbox_width, bbox_height])\n",
" bboxes = np.array(xml_list)\n",
" return bboxes\n",
"\n",
......@@ -275,10 +284,10 @@
" assert len(bboxes), \"You must provide bounding boxes\"\n",
"\n",
" normalized_bboxes = bboxes / np.sqrt(bboxes.prod(axis=1, keepdims=True))\n",
"\n",
" # Using kmeans to find centroids of the width/height clusters\n",
" \n",
" # Using kmeans to find centroids of the width/height clusters\n",
" kmeans = KMeans(\n",
" init='random', n_clusters=num_aspect_ratios,random_state=0, max_iter=kmeans_max_iter)\n",
" init='random', n_clusters=num_aspect_ratios, random_state=0, max_iter=kmeans_max_iter)\n",
" kmeans.fit(X=normalized_bboxes)\n",
" ar = kmeans.cluster_centers_\n",
"\n",
......@@ -292,9 +301,7 @@
" aspect_ratios = [w/h for w,h in ar]\n",
"\n",
" return aspect_ratios, avg_iou_perc"
],
"execution_count": null,
"outputs": []
]
},
{
"cell_type": "markdown",
......@@ -323,13 +330,12 @@
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "cNw-vX3nfl1g"
},
"outputs": [],
"source": [
"classes = ['Abyssinian','american_bulldog']\n",
"xml_path = '/content/dataset/annotations/xmls'\n",
"\n",
"# Tune this based on your accuracy/speed goals as described above\n",
"num_aspect_ratios = 4 # can be [2,3,4,5,6]\n",
"\n",
......@@ -342,8 +348,7 @@
"height = 320\n",
"\n",
"# Get the ground-truth bounding boxes for our dataset\n",
"bboxes = xml_to_boxes(path=xml_path, classes=classes,\n",
" rescale_width=width, rescale_height=height)\n",
"bboxes = xml_to_boxes(path=XML_PATH, rescale_width=width, rescale_height=height)\n",
"\n",
"aspect_ratios, avg_iou_perc = kmeans_aspect_ratios(\n",
" bboxes=bboxes,\n",
......@@ -354,9 +359,7 @@
"\n",
"print('Aspect ratios generated:', [round(ar,2) for ar in aspect_ratios])\n",
"print('Average IOU with anchors:', avg_iou_perc)"
],
"execution_count": null,
"outputs": []
]
},
{
"cell_type": "markdown",
......@@ -378,9 +381,11 @@
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "AlMffd3rgKW2"
},
"outputs": [],
"source": [
"import tensorflow as tf\n",
"from google.protobuf import text_format\n",
......@@ -404,9 +409,7 @@
" f.write(config_text)\n",
"# Check for updated aspect ratios in the config\n",
"!cat /content/ssdlite_mobiledet_edgetpu_320x320_custom_aspect_ratios.config"
],
"execution_count": null,
"outputs": []
]
},
{
"cell_type": "markdown",
......@@ -441,5 +444,22 @@
"\n"
]
}
]
}
\ No newline at end of file
],
"metadata": {
"colab": {
"collapsed_sections": [],
"name": "Generate_SSD_anchor_box_aspect_ratios_using_k_means_clustering.ipynb",
"provenance": [],
"toc_visible": true
},
"kernelspec": {
"display_name": "Python 3",
"name": "python3"
},
"language_info": {
"name": "python"
}
},
"nbformat": 4,
"nbformat_minor": 0
}
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment