{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# k-means with soft-DTW distance\n",
    "\n",
    "Clusters random 2-D points with `kmeans` (`distance='soft_dtw'`) from `kmeans_pytorch`, assigns a few new points to the learned clusters with `kmeans_predict`, and plots the result. Intended to be run top to bottom on a fresh kernel."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "R8_IiKZZEU9i"
   },
   "outputs": [],
   "source": [
    "import torch\n",
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "from kmeans_pytorch import kmeans, kmeans_predict"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "iyoljrh1FCxJ"
   },
   "outputs": [],
   "source": [
    "# set random seeds: numpy generates the data below; torch is seeded as well\n",
    "# in case kmeans_pytorch draws from torch's RNG internally\n",
    "np.random.seed(123)\n",
    "torch.manual_seed(123);"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "tcnoTA16FHbJ"
   },
   "outputs": [],
   "source": [
    "# data: data_size points in dims dimensions, standard normal scaled by 1/6\n",
    "data_size, dims, num_clusters = 1000, 2, 3\n",
    "x = torch.from_numpy(np.random.randn(data_size, dims) / 6)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "Ar-lcW3OFTXI"
   },
   "outputs": [],
   "source": [
    "# set device: first CUDA GPU when available, otherwise CPU\n",
    "device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 71
    },
    "colab_type": "code",
    "id": "KsM9zQZ5FYKp",
    "outputId": "c37d7629-560f-4191-bbee-2f5523d2cda2"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "running k-means on cuda:0..\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "[running kmeans]: 6it [00:00, 73.67it/s, center_shift=0.000068, iteration=7, tol=0.000100]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "device is :cuda:0\n",
      "device is :cuda:0\n",
      "device is :cuda:0\n",
      "device is :cuda:0\n",
      "device is :cuda:0\n",
      "device is :cuda:0\n",
      "device is :cuda:0\n"
     ]
    }
   ],
   "source": [
    "# k-means with soft-DTW distance; reuse this gamma when calling kmeans_predict\n",
    "cluster_ids_x, cluster_centers = kmeans(\n",
    "    X=x, num_clusters=num_clusters, distance='soft_dtw', device=device,\n",
    "    gamma_for_soft_dtw=0.0001\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {
"colab": { "base_uri": "https://localhost:8080/", "height": 797 }, "colab_type": "code", "id": "IdzkYHBEFdja", "outputId": "3bd48cc3-487c-40f7-826a-c8fccecfce89" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "tensor([2, 0, 2, 0, 1, 0, 1, 0, 1, 1, 2, 2, 0, 1, 0, 0, 0, 1, 2, 2, 0, 2, 1, 1,\n", " 2, 0, 1, 2, 2, 1, 2, 0, 1, 1, 2, 1, 1, 2, 0, 0, 1, 1, 0, 0, 1, 1, 2, 2,\n", " 0, 1, 0, 2, 1, 0, 0, 2, 2, 1, 0, 1, 0, 2, 1, 1, 1, 0, 2, 1, 2, 1, 2, 1,\n", " 1, 2, 2, 1, 0, 2, 1, 1, 1, 2, 1, 1, 1, 0, 2, 2, 1, 2, 2, 1, 0, 0, 2, 1,\n", " 1, 0, 0, 0, 1, 1, 1, 0, 2, 1, 0, 2, 1, 2, 0, 0, 1, 0, 2, 2, 2, 1, 1, 1,\n", " 1, 0, 1, 0, 2, 1, 0, 1, 1, 2, 1, 0, 1, 1, 1, 0, 1, 0, 1, 0, 0, 0, 0, 2,\n", " 2, 2, 1, 2, 1, 2, 1, 2, 2, 2, 1, 1, 2, 1, 0, 2, 0, 0, 1, 0, 2, 2, 0, 0,\n", " 2, 1, 0, 1, 0, 2, 2, 0, 0, 0, 2, 0, 2, 2, 2, 1, 1, 0, 1, 2, 2, 0, 1, 0,\n", " 2, 2, 1, 1, 0, 0, 2, 2, 1, 0, 2, 0, 2, 1, 2, 1, 1, 0, 2, 0, 0, 2, 2, 2,\n", " 0, 1, 0, 1, 1, 2, 1, 2, 1, 0, 0, 2, 2, 2, 2, 0, 1, 1, 1, 2, 1, 0, 2, 0,\n", " 0, 2, 2, 1, 1, 0, 0, 2, 1, 1, 1, 2, 1, 0, 0, 1, 1, 2, 2, 1, 0, 0, 2, 1,\n", " 1, 0, 1, 2, 1, 2, 0, 2, 2, 0, 2, 1, 0, 1, 1, 1, 2, 0, 1, 2, 2, 1, 1, 1,\n", " 0, 1, 0, 1, 2, 0, 2, 1, 2, 1, 0, 1, 1, 1, 1, 1, 1, 2, 1, 1, 2, 0, 2, 1,\n", " 0, 0, 2, 0, 2, 0, 1, 2, 1, 2, 0, 0, 2, 1, 1, 1, 1, 0, 2, 0, 2, 2, 1, 0,\n", " 1, 2, 2, 1, 1, 1, 2, 2, 0, 0, 1, 2, 1, 1, 0, 1, 2, 1, 2, 0, 0, 2, 0, 1,\n", " 1, 1, 2, 2, 1, 2, 0, 2, 0, 0, 2, 0, 2, 1, 2, 1, 1, 2, 2, 0, 1, 0, 0, 0,\n", " 0, 1, 0, 2, 2, 1, 0, 2, 0, 0, 2, 2, 2, 0, 1, 2, 0, 2, 2, 1, 2, 1, 2, 1,\n", " 0, 0, 0, 2, 0, 2, 2, 2, 0, 1, 1, 0, 2, 2, 0, 2, 2, 1, 0, 0, 2, 2, 0, 0,\n", " 1, 0, 1, 2, 0, 2, 0, 1, 0, 0, 0, 1, 2, 2, 1, 1, 2, 1, 1, 1, 0, 0, 2, 0,\n", " 0, 0, 2, 1, 1, 1, 2, 2, 2, 2, 0, 0, 1, 2, 0, 0, 1, 2, 1, 0, 1, 0, 2, 2,\n", " 0, 0, 0, 0, 2, 1, 0, 2, 1, 1, 2, 1, 0, 2, 0, 2, 0, 2, 1, 1, 2, 1, 0, 0,\n", " 0, 1, 2, 1, 1, 0, 2, 0, 2, 1, 2, 1, 1, 2, 0, 1, 0, 0, 2, 0, 2, 2, 2, 1,\n", " 1, 2, 1, 1, 2, 1, 1, 1, 1, 0, 0, 2, 1, 1, 
2, 1, 1, 2, 0, 0, 2, 1, 2, 1,\n",
      " 1, 1, 1, 1, 0, 0, 2, 2, 1, 0, 1, 2, 2, 0, 1, 0, 2, 0, 2, 2, 2, 0, 1, 2,\n",
      " 2, 0, 1, 2, 1, 1, 2, 1, 2, 1, 0, 2, 0, 2, 1, 0, 2, 0, 1, 2, 1, 2, 1, 0,\n",
      " 1, 1, 0, 1, 1, 0, 1, 1, 0, 1, 0, 1, 1, 2, 0, 0, 2, 1, 0, 1, 1, 1, 0, 0,\n",
      " 2, 1, 0, 2, 1, 1, 0, 2, 1, 2, 0, 2, 2, 1, 1, 0, 0, 2, 0, 2, 1, 1, 0, 1,\n",
      " 1, 0, 2, 2, 2, 1, 0, 0, 2, 1, 1, 1, 2, 1, 0, 1, 1, 1, 2, 2, 1, 1, 2, 1,\n",
      " 0, 1, 0, 0, 0, 2, 0, 1, 0, 0, 1, 1, 0, 1, 2, 1, 1, 1, 1, 0, 1, 0, 0, 2,\n",
      " 1, 2, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 2, 2, 2, 0,\n",
      " 2, 2, 0, 2, 2, 1, 1, 1, 1, 0, 2, 1, 2, 2, 1, 1, 1, 1, 1, 1, 2, 1, 0, 2,\n",
      " 2, 0, 2, 2, 0, 2, 1, 0, 0, 2, 0, 0, 1, 0, 2, 2, 0, 1, 2, 0, 0, 1, 1, 2,\n",
      " 2, 2, 0, 1, 2, 0, 0, 1, 2, 2, 0, 1, 0, 0, 2, 2, 0, 2, 1, 0, 1, 1, 2, 1,\n",
      " 0, 2, 1, 1, 0, 1, 1, 0, 2, 2, 2, 2, 1, 0, 0, 0, 2, 1, 2, 2, 0, 0, 0, 2,\n",
      " 1, 2, 1, 0, 2, 0, 0, 1, 1, 2, 2, 1, 2, 1, 2, 0, 0, 2, 1, 0, 1, 0, 0, 2,\n",
      " 2, 2, 2, 0, 1, 2, 2, 2, 2, 2, 1, 0, 0, 1, 1, 0, 2, 0, 2, 0, 2, 0, 0, 1,\n",
      " 0, 0, 0, 2, 0, 2, 1, 2, 0, 1, 0, 2, 0, 0, 0, 1, 0, 1, 1, 1, 0, 2, 2, 0,\n",
      " 1, 2, 0, 1, 1, 2, 2, 1, 2, 1, 0, 1, 0, 2, 1, 1, 2, 1, 1, 2, 2, 0, 1, 0,\n",
      " 2, 2, 0, 2, 2, 2, 1, 1, 0, 1, 2, 0, 2, 1, 0, 2, 1, 0, 1, 0, 2, 2, 2, 2,\n",
      " 2, 2, 1, 1, 2, 2, 2, 1, 2, 2, 1, 0, 0, 1, 1, 2, 1, 0, 1, 1, 1, 0, 2, 2,\n",
      " 2, 2, 1, 2, 0, 1, 2, 0, 1, 1, 1, 1, 2, 2, 0, 0, 2, 1, 1, 1, 0, 2, 0, 2,\n",
      " 2, 2, 0, 2, 1, 2, 2, 2, 2, 2, 2, 2, 2, 1, 0, 0])\n",
      "tensor([[-0.1075, -0.1522],\n",
      " [ 0.1544, -0.0137],\n",
      " [-0.0833, 0.1454]])\n"
     ]
    }
   ],
   "source": [
    "# inspect the fit: the cluster id assigned to every point of x, then the cluster centers\n",
    "print(cluster_ids_x)\n",
    "print(cluster_centers)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Assign new points\n",
    "\n",
    "Draw 5 extra points with the same scaling as `x` and assign them to the learned clusters with `kmeans_predict`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "32XHxknWFayP"
   },
   "outputs": [],
   "source": [
    "# more data: 5 new points, standard normal scaled by 1/6, converted straight to a tensor\n",
    "y = torch.from_numpy(np.random.randn(5, dims) / 6)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {
    "colab": {
"base_uri": "https://localhost:8080/",
     "height": 34
    },
    "colab_type": "code",
    "id": "gQB3jVTKFfsN",
    "outputId": "1f2f634a-4a32-4b05-ca7f-231165e974ea"
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "[running kmeans]: 7it [09:38, 82.70s/it, center_shift=0.000068, iteration=7, tol=0.000100]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "predicting on cuda:0..\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    }
   ],
   "source": [
    "# predict cluster ids for y\n",
    "# pass the same soft-DTW gamma the kmeans() fit used (0.0001); without it\n",
    "# kmeans_predict falls back to its own default gamma for the distance\n",
    "cluster_ids_y = kmeans_predict(\n",
    "    y, cluster_centers, 'soft_dtw', device=device,\n",
    "    gamma_for_soft_dtw=0.0001\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 34
    },
    "colab_type": "code",
    "id": "CGoD8s6_FiDp",
    "outputId": "7137fa6a-f5a8-4096-f5d2-fc74d45eaf2f"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor([1, 2, 0, 1, 2])\n"
     ]
    }
   ],
   "source": [
    "print(cluster_ids_y)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Plot\n",
    "\n",
    "Circles are the clustered points of `x`, `X` markers are the new points `y`, and translucent white discs are the cluster centers. All points are coloured by cluster id."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 481
    },
    "colab_type": "code",
    "id": "5V0vxpaUEnFd",
    "outputId": "a97b6261-034e-493a-d8a6-cf955ab5e8fd"
   },
   "outputs": [
    {
     "data": {
      "image/png": "\n",
      "text/plain": [
       "<Figure size 640x480 with 1 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    }
   ],
   "source": [
    "# plot\n",
    "fig, ax = plt.subplots(figsize=(4, 3), dpi=160)\n",
    "# pin the colour scale to the full id range so a given cluster id gets the same\n",
    "# colour in both scatters, even when y does not contain every cluster id\n",
    "id_colors = dict(cmap='cool', vmin=0, vmax=num_clusters - 1)\n",
    "ax.scatter(x[:, 0], x[:, 1], c=cluster_ids_x, **id_colors)\n",
    "ax.scatter(y[:, 0], y[:, 1], c=cluster_ids_y, marker='X', **id_colors)\n",
    "ax.scatter(\n",
    "    cluster_centers[:, 0], cluster_centers[:, 1],\n",
    "    c='white',\n",
    "    alpha=0.6,\n",
    "    edgecolors='black',\n",
    "    linewidths=2\n",
    ")\n",
    "ax.axis([-1, 1, -1, 1])\n",
    "ax.set(title='soft-DTW k-means clusters', xlabel='dim 0', ylabel='dim 1')\n",
    "fig.tight_layout()\n",
    "plt.show()"
   ]
  }
 ],
 "metadata": {
  "accelerator": "GPU",
  "colab": {
   "collapsed_sections": [],
   "name": "example.ipynb",
   "provenance": []
  },
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.8"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}