{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# Retiarii Example - One-shot NAS" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "This example shows how Retiarii lets you **express** and **explore** a model space for Neural Architecture Search and Hyper-Parameter Tuning in a simple way. A video demo is available on [YouTube](https://youtu.be/3nEx9GMHYEk) and [Bilibili](https://www.bilibili.com/video/BV1c54y1V7vx/).\n", "\n", "Let's start the journey with Retiarii!" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Step 1: Express the Model Space" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "A model space is defined by users to express the set of models they want to explore, which should contain potentially well-performing models. In the Retiarii framework, a model space is defined in two parts: a base model and the possible mutations of that base model." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Step 1.1: Define the Base Model" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Defining a base model is almost the same as defining a PyTorch (or TensorFlow) model. Usually, you only need to replace ``import torch.nn as nn`` with ``import nni.retiarii.nn.pytorch as nn`` to use the NNI-wrapped PyTorch modules. Below is a very simple example of defining a base model."
] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import torch.nn.functional as F\n", "import nni.retiarii.nn.pytorch as nn\n", "\n", "class Net(nn.Module):\n", "    def __init__(self):\n", "        super(Net, self).__init__()\n", "        self.conv1 = nn.Conv2d(3, 6, 3, padding=1)\n", "        self.pool = nn.MaxPool2d(2, 2)\n", "        self.conv2 = nn.Conv2d(6, 16, 3, padding=1)\n", "        self.conv3 = nn.Conv2d(16, 16, 1)\n", "\n", "        self.bn = nn.BatchNorm2d(16)\n", "\n", "        self.gap = nn.AdaptiveAvgPool2d(4)\n", "        self.fc1 = nn.Linear(16 * 4 * 4, 120)\n", "        self.fc2 = nn.Linear(120, 84)\n", "        self.fc3 = nn.Linear(84, 10)\n", "\n", "    def forward(self, x):\n", "        bs = x.size(0)\n", "\n", "        x = self.pool(F.relu(self.conv1(x)))\n", "        x0 = F.relu(self.conv2(x))\n", "        x1 = F.relu(self.conv3(x0))\n", "\n", "        # skip connection; added out of place, since an in-place += on a ReLU output can break autograd\n", "        x1 = x1 + x0\n", "        x = self.pool(self.bn(x1))\n", "\n", "        x = self.gap(x).view(bs, -1)\n", "        x = F.relu(self.fc1(x))\n", "        x = F.relu(self.fc2(x))\n", "        x = self.fc3(x)\n", "        return x\n", "\n", "model = Net()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Step 1.2: Define the Model Mutations" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "A base model is only one concrete model, not a model space. NNI provides APIs and primitives for users to express how the base model can be mutated, that is, to define a model space that includes many models. The following example uses two inline mutation APIs: ``LayerChoice`` to choose a layer from candidate operations, and ``InputChoice`` to try out a skip connection."
] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [], "source": [ "import torch.nn.functional as F\n", "import nni.retiarii.nn.pytorch as nn\n", "\n", "class Net(nn.Module):\n", "    def __init__(self):\n", "        super(Net, self).__init__()\n", "        # LayerChoice: the search picks one of the candidate layers (3x3 or 5x5 convolution)\n", "        # self.conv1 = nn.Conv2d(3, 6, 3, padding=1)\n", "        self.conv1 = nn.LayerChoice([nn.Conv2d(3, 6, 3, padding=1), nn.Conv2d(3, 6, 5, padding=2)])\n", "        self.pool = nn.MaxPool2d(2, 2)\n", "        # self.conv2 = nn.Conv2d(6, 16, 3, padding=1)\n", "        self.conv2 = nn.LayerChoice([nn.Conv2d(6, 16, 3, padding=1), nn.Conv2d(6, 16, 5, padding=2)])\n", "        self.conv3 = nn.Conv2d(16, 16, 1)\n", "\n", "        # InputChoice: the search picks one of the two candidate inputs passed in forward()\n", "        self.skipconnect = nn.InputChoice(n_candidates=2)\n", "        self.bn = nn.BatchNorm2d(16)\n", "\n", "        self.gap = nn.AdaptiveAvgPool2d(4)\n", "        self.fc1 = nn.Linear(16 * 4 * 4, 120)\n", "        self.fc2 = nn.Linear(120, 84)\n", "        self.fc3 = nn.Linear(84, 10)\n", "\n", "    def forward(self, x):\n", "        bs = x.size(0)\n", "\n", "        x = self.pool(F.relu(self.conv1(x)))\n", "        x0 = F.relu(self.conv2(x))\n", "        x1 = F.relu(self.conv3(x0))\n", "\n", "        # candidates: without the skip connection (x1) or with it (x1 + x0)\n", "        x1 = self.skipconnect([x1, x1+x0])\n", "        x = self.pool(self.bn(x1))\n", "\n", "        x = self.gap(x).view(bs, -1)\n", "        x = F.relu(self.fc1(x))\n", "        x = F.relu(self.fc2(x))\n", "        x = self.fc3(x)\n", "        return x\n", "\n", "model = Net()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Step 2: Explore the Model Space" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "With a defined model space, users can explore the space in two ways. One is the multi-trial NAS method, which searches by evaluating each sampled model independently. The other is one-shot, weight-sharing based search, which consumes far fewer computational resources than the first." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "In this part, we focus on the **one-shot** approach. 
The principle of the one-shot approach is to combine all the models in a model space into one big model (usually called a super-model or super-graph). Search, training, and testing are all carried out by training and evaluating this big model.\n", "\n", "Retiarii supports several classic one-shot trainers, such as the DARTS, ENAS, ProxylessNAS, and Single-path trainers, and users can conveniently implement a new one-shot trainer with the APIs that Retiarii provides.\n", "\n", "Here, we show an example of using the DARTS trainer manually." ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Files already downloaded and verified\n", "[2021-06-07 11:12:22] INFO (nni.retiarii.oneshot.pytorch.darts/MainThread) Epoch [1/2] Step [1/391] acc1 0.093750 (0.093750) loss 2.286068 (2.286068)\n", "[2021-06-07 11:12:22] INFO (nni.retiarii.oneshot.pytorch.darts/MainThread) Epoch [1/2] Step [11/391] acc1 0.093750 (0.089489) loss 2.328799 (2.309416)\n", "[2021-06-07 11:12:23] INFO (nni.retiarii.oneshot.pytorch.darts/MainThread) Epoch [1/2] Step [21/391] acc1 0.093750 (0.092262) loss 2.302527 (2.309082)\n", "[2021-06-07 11:12:23] INFO (nni.retiarii.oneshot.pytorch.darts/MainThread) Epoch [1/2] Step [31/391] acc1 0.109375 (0.099294) loss 2.294730 (2.304962)\n", "[2021-06-07 11:12:23] INFO (nni.retiarii.oneshot.pytorch.darts/MainThread) Epoch [1/2] Step [41/391] acc1 0.203125 (0.103277) loss 2.284227 (2.302716)\n", "[2021-06-07 11:12:23] INFO (nni.retiarii.oneshot.pytorch.darts/MainThread) Epoch [1/2] Step [51/391] acc1 0.078125 (0.106618) loss 2.308704 (2.300639)\n", "[2021-06-07 11:12:23] INFO (nni.retiarii.oneshot.pytorch.darts/MainThread) Epoch [1/2] Step [61/391] acc1 0.203125 (0.110143) loss 2.258595 (2.298042)\n", "[2021-06-07 11:12:23] INFO (nni.retiarii.oneshot.pytorch.darts/MainThread) Epoch [1/2] Step [71/391] acc1 0.078125 (0.112896) loss 2.276706 (2.294709)\n", "[2021-06-07 
11:12:24] INFO (nni.retiarii.oneshot.pytorch.darts/MainThread) Epoch [1/2] Step [81/391] acc1 0.078125 (0.116898) loss 2.309119 (2.292235)\n", "[2021-06-07 11:12:24] INFO (nni.retiarii.oneshot.pytorch.darts/MainThread) Epoch [1/2] Step [91/391] acc1 0.093750 (0.118304) loss 2.263757 (2.289659)\n", "[2021-06-07 11:12:24] INFO (nni.retiarii.oneshot.pytorch.darts/MainThread) Epoch [1/2] Step [101/391] acc1 0.109375 (0.119431) loss 2.260739 (2.287132)\n", "[2021-06-07 11:12:24] INFO (nni.retiarii.oneshot.pytorch.darts/MainThread) Epoch [1/2] Step [111/391] acc1 0.109375 (0.121481) loss 2.279930 (2.284314)\n", "[2021-06-07 11:12:24] INFO (nni.retiarii.oneshot.pytorch.darts/MainThread) Epoch [1/2] Step [121/391] acc1 0.046875 (0.122934) loss 2.270205 (2.281701)\n", "[2021-06-07 11:12:25] INFO (nni.retiarii.oneshot.pytorch.darts/MainThread) Epoch [1/2] Step [131/391] acc1 0.156250 (0.125477) loss 2.270163 (2.278612)\n", "[2021-06-07 11:12:25] INFO (nni.retiarii.oneshot.pytorch.darts/MainThread) Epoch [1/2] Step [141/391] acc1 0.171875 (0.126551) loss 2.233467 (2.276326)\n", "[2021-06-07 11:12:25] INFO (nni.retiarii.oneshot.pytorch.darts/MainThread) Epoch [1/2] Step [151/391] acc1 0.109375 (0.127897) loss 2.264694 (2.274296)\n", "[2021-06-07 11:12:25] INFO (nni.retiarii.oneshot.pytorch.darts/MainThread) Epoch [1/2] Step [161/391] acc1 0.250000 (0.132279) loss 2.259590 (2.271723)\n", "[2021-06-07 11:12:25] INFO (nni.retiarii.oneshot.pytorch.darts/MainThread) Epoch [1/2] Step [171/391] acc1 0.093750 (0.134868) loss 2.240986 (2.269037)\n", "[2021-06-07 11:12:25] INFO (nni.retiarii.oneshot.pytorch.darts/MainThread) Epoch [1/2] Step [181/391] acc1 0.218750 (0.137690) loss 2.218153 (2.266567)\n", "[2021-06-07 11:12:25] INFO (nni.retiarii.oneshot.pytorch.darts/MainThread) Epoch [1/2] Step [191/391] acc1 0.078125 (0.140134) loss 2.260816 (2.264373)\n", "[2021-06-07 11:12:26] INFO (nni.retiarii.oneshot.pytorch.darts/MainThread) Epoch [1/2] Step [201/391] acc1 0.156250 (0.144123) 
loss 2.191213 (2.261285)\n", "[2021-06-07 11:12:26] INFO (nni.retiarii.oneshot.pytorch.darts/MainThread) Epoch [1/2] Step [211/391] acc1 0.125000 (0.146919) loss 2.245425 (2.258747)\n", "[2021-06-07 11:12:26] INFO (nni.retiarii.oneshot.pytorch.darts/MainThread) Epoch [1/2] Step [221/391] acc1 0.218750 (0.150028) loss 2.216708 (2.255553)\n", "[2021-06-07 11:12:26] INFO (nni.retiarii.oneshot.pytorch.darts/MainThread) Epoch [1/2] Step [231/391] acc1 0.250000 (0.153003) loss 2.195549 (2.252894)\n", "[2021-06-07 11:12:26] INFO (nni.retiarii.oneshot.pytorch.darts/MainThread) Epoch [1/2] Step [241/391] acc1 0.234375 (0.155666) loss 2.169693 (2.249465)\n", "[2021-06-07 11:12:26] INFO (nni.retiarii.oneshot.pytorch.darts/MainThread) Epoch [1/2] Step [251/391] acc1 0.218750 (0.158989) loss 2.174878 (2.246355)\n", "[2021-06-07 11:12:27] INFO (nni.retiarii.oneshot.pytorch.darts/MainThread) Epoch [1/2] Step [261/391] acc1 0.312500 (0.162775) loss 2.117693 (2.243113)\n", "[2021-06-07 11:12:27] INFO (nni.retiarii.oneshot.pytorch.darts/MainThread) Epoch [1/2] Step [271/391] acc1 0.265625 (0.166686) loss 2.136203 (2.239288)\n", "[2021-06-07 11:12:27] INFO (nni.retiarii.oneshot.pytorch.darts/MainThread) Epoch [1/2] Step [281/391] acc1 0.234375 (0.169095) loss 2.213463 (2.236377)\n", "[2021-06-07 11:12:27] INFO (nni.retiarii.oneshot.pytorch.darts/MainThread) Epoch [1/2] Step [291/391] acc1 0.218750 (0.171338) loss 2.114096 (2.232892)\n", "[2021-06-07 11:12:27] INFO (nni.retiarii.oneshot.pytorch.darts/MainThread) Epoch [1/2] Step [301/391] acc1 0.203125 (0.173432) loss 2.134074 (2.229637)\n", "[2021-06-07 11:12:28] INFO (nni.retiarii.oneshot.pytorch.darts/MainThread) Epoch [1/2] Step [311/391] acc1 0.265625 (0.175291) loss 2.041354 (2.225920)\n", "[2021-06-07 11:12:28] INFO (nni.retiarii.oneshot.pytorch.darts/MainThread) Epoch [1/2] Step [321/391] acc1 0.250000 (0.176840) loss 2.081122 (2.222280)\n", "[2021-06-07 11:12:28] INFO (nni.retiarii.oneshot.pytorch.darts/MainThread) Epoch 
[1/2] Step [331/391] acc1 0.140625 (0.178578) loss 2.124206 (2.219168)\n", "[2021-06-07 11:12:28] INFO (nni.retiarii.oneshot.pytorch.darts/MainThread) Epoch [1/2] Step [341/391] acc1 0.250000 (0.180169) loss 2.077291 (2.215540)\n", "[2021-06-07 11:12:28] INFO (nni.retiarii.oneshot.pytorch.darts/MainThread) Epoch [1/2] Step [351/391] acc1 0.250000 (0.182381) loss 2.077531 (2.211650)\n", "[2021-06-07 11:12:28] INFO (nni.retiarii.oneshot.pytorch.darts/MainThread) Epoch [1/2] Step [361/391] acc1 0.312500 (0.185033) loss 2.016619 (2.207455)\n", "[2021-06-07 11:12:29] INFO (nni.retiarii.oneshot.pytorch.darts/MainThread) Epoch [1/2] Step [371/391] acc1 0.250000 (0.187163) loss 2.139604 (2.202785)\n", "[2021-06-07 11:12:29] INFO (nni.retiarii.oneshot.pytorch.darts/MainThread) Epoch [1/2] Step [381/391] acc1 0.281250 (0.189099) loss 2.033739 (2.198564)\n", "[2021-06-07 11:12:29] INFO (nni.retiarii.oneshot.pytorch.darts/MainThread) Epoch [1/2] Step [391/391] acc1 0.275000 (0.190441) loss 1.988353 (2.194509)\n", "[2021-06-07 11:12:29] INFO (nni.retiarii.oneshot.pytorch.darts/MainThread) Epoch [2/2] Step [1/391] acc1 0.296875 (0.296875) loss 2.083627 (2.083627)\n", "[2021-06-07 11:12:30] INFO (nni.retiarii.oneshot.pytorch.darts/MainThread) Epoch [2/2] Step [11/391] acc1 0.265625 (0.251420) loss 2.042856 (2.050898)\n", "[2021-06-07 11:12:30] INFO (nni.retiarii.oneshot.pytorch.darts/MainThread) Epoch [2/2] Step [21/391] acc1 0.234375 (0.273065) loss 2.005307 (2.021047)\n", "[2021-06-07 11:12:30] INFO (nni.retiarii.oneshot.pytorch.darts/MainThread) Epoch [2/2] Step [31/391] acc1 0.375000 (0.269657) loss 1.934093 (2.014375)\n", "[2021-06-07 11:12:30] INFO (nni.retiarii.oneshot.pytorch.darts/MainThread) Epoch [2/2] Step [41/391] acc1 0.265625 (0.277439) loss 2.007705 (2.003260)\n", "[2021-06-07 11:12:30] INFO (nni.retiarii.oneshot.pytorch.darts/MainThread) Epoch [2/2] Step [51/391] acc1 0.218750 (0.278799) loss 2.014602 (2.001039)\n", "[2021-06-07 11:12:31] INFO 
(nni.retiarii.oneshot.pytorch.darts/MainThread) Epoch [2/2] Step [61/391] acc1 0.187500 (0.278945) loss 2.088407 (1.995837)\n", "[2021-06-07 11:12:31] INFO (nni.retiarii.oneshot.pytorch.darts/MainThread) Epoch [2/2] Step [71/391] acc1 0.343750 (0.285651) loss 1.894479 (1.988130)\n", "[2021-06-07 11:12:31] INFO (nni.retiarii.oneshot.pytorch.darts/MainThread) Epoch [2/2] Step [81/391] acc1 0.281250 (0.289159) loss 1.869002 (1.979012)\n", "[2021-06-07 11:12:31] INFO (nni.retiarii.oneshot.pytorch.darts/MainThread) Epoch [2/2] Step [91/391] acc1 0.265625 (0.291552) loss 1.848354 (1.971483)\n", "[2021-06-07 11:12:31] INFO (nni.retiarii.oneshot.pytorch.darts/MainThread) Epoch [2/2] Step [101/391] acc1 0.406250 (0.290996) loss 1.840711 (1.964297)\n", "[2021-06-07 11:12:31] INFO (nni.retiarii.oneshot.pytorch.darts/MainThread) Epoch [2/2] Step [111/391] acc1 0.390625 (0.294764) loss 1.905811 (1.958954)\n", "[2021-06-07 11:12:32] INFO (nni.retiarii.oneshot.pytorch.darts/MainThread) Epoch [2/2] Step [121/391] acc1 0.250000 (0.296617) loss 1.935214 (1.952315)\n", "[2021-06-07 11:12:32] INFO (nni.retiarii.oneshot.pytorch.darts/MainThread) Epoch [2/2] Step [131/391] acc1 0.281250 (0.299618) loss 1.901846 (1.944634)\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "[2021-06-07 11:12:32] INFO (nni.retiarii.oneshot.pytorch.darts/MainThread) Epoch [2/2] Step [141/391] acc1 0.312500 (0.302970) loss 1.854658 (1.939751)\n", "[2021-06-07 11:12:32] INFO (nni.retiarii.oneshot.pytorch.darts/MainThread) Epoch [2/2] Step [151/391] acc1 0.218750 (0.305257) loss 1.927818 (1.934704)\n", "[2021-06-07 11:12:32] INFO (nni.retiarii.oneshot.pytorch.darts/MainThread) Epoch [2/2] Step [161/391] acc1 0.343750 (0.307648) loss 1.820810 (1.927533)\n", "[2021-06-07 11:12:33] INFO (nni.retiarii.oneshot.pytorch.darts/MainThread) Epoch [2/2] Step [171/391] acc1 0.312500 (0.307383) loss 1.800313 (1.924665)\n", "[2021-06-07 11:12:33] INFO (nni.retiarii.oneshot.pytorch.darts/MainThread) Epoch [2/2] 
Step [181/391] acc1 0.484375 (0.307925) loss 1.637479 (1.920402)\n", "[2021-06-07 11:12:33] INFO (nni.retiarii.oneshot.pytorch.darts/MainThread) Epoch [2/2] Step [191/391] acc1 0.359375 (0.306692) loss 1.732374 (1.917680)\n", "[2021-06-07 11:12:33] INFO (nni.retiarii.oneshot.pytorch.darts/MainThread) Epoch [2/2] Step [201/391] acc1 0.406250 (0.309624) loss 1.870701 (1.911484)\n", "[2021-06-07 11:12:33] INFO (nni.retiarii.oneshot.pytorch.darts/MainThread) Epoch [2/2] Step [211/391] acc1 0.328125 (0.311982) loss 1.785704 (1.905039)\n", "[2021-06-07 11:12:33] INFO (nni.retiarii.oneshot.pytorch.darts/MainThread) Epoch [2/2] Step [221/391] acc1 0.265625 (0.312712) loss 1.738683 (1.901547)\n", "[2021-06-07 11:12:33] INFO (nni.retiarii.oneshot.pytorch.darts/MainThread) Epoch [2/2] Step [231/391] acc1 0.359375 (0.315409) loss 1.827117 (1.894860)\n", "[2021-06-07 11:12:34] INFO (nni.retiarii.oneshot.pytorch.darts/MainThread) Epoch [2/2] Step [241/391] acc1 0.375000 (0.317881) loss 1.717454 (1.888916)\n", "[2021-06-07 11:12:34] INFO (nni.retiarii.oneshot.pytorch.darts/MainThread) Epoch [2/2] Step [251/391] acc1 0.328125 (0.318663) loss 1.873310 (1.886883)\n", "[2021-06-07 11:12:34] INFO (nni.retiarii.oneshot.pytorch.darts/MainThread) Epoch [2/2] Step [261/391] acc1 0.390625 (0.320163) loss 1.657088 (1.881767)\n", "[2021-06-07 11:12:34] INFO (nni.retiarii.oneshot.pytorch.darts/MainThread) Epoch [2/2] Step [271/391] acc1 0.421875 (0.321264) loss 1.710897 (1.877521)\n", "[2021-06-07 11:12:34] INFO (nni.retiarii.oneshot.pytorch.darts/MainThread) Epoch [2/2] Step [281/391] acc1 0.421875 (0.321230) loss 1.760745 (1.875136)\n", "[2021-06-07 11:12:34] INFO (nni.retiarii.oneshot.pytorch.darts/MainThread) Epoch [2/2] Step [291/391] acc1 0.375000 (0.321413) loss 1.669255 (1.872129)\n", "[2021-06-07 11:12:34] INFO (nni.retiarii.oneshot.pytorch.darts/MainThread) Epoch [2/2] Step [301/391] acc1 0.328125 (0.322051) loss 1.728873 (1.868047)\n", "[2021-06-07 11:12:35] INFO 
(nni.retiarii.oneshot.pytorch.darts/MainThread) Epoch [2/2] Step [311/391] acc1 0.375000 (0.323000) loss 1.754761 (1.864783)\n", "[2021-06-07 11:12:35] INFO (nni.retiarii.oneshot.pytorch.darts/MainThread) Epoch [2/2] Step [321/391] acc1 0.437500 (0.324864) loss 1.666240 (1.859164)\n", "[2021-06-07 11:12:35] INFO (nni.retiarii.oneshot.pytorch.darts/MainThread) Epoch [2/2] Step [331/391] acc1 0.421875 (0.325954) loss 1.661471 (1.856318)\n", "[2021-06-07 11:12:35] INFO (nni.retiarii.oneshot.pytorch.darts/MainThread) Epoch [2/2] Step [341/391] acc1 0.328125 (0.326475) loss 1.737106 (1.853075)\n", "[2021-06-07 11:12:35] INFO (nni.retiarii.oneshot.pytorch.darts/MainThread) Epoch [2/2] Step [351/391] acc1 0.343750 (0.327724) loss 1.789253 (1.849491)\n", "[2021-06-07 11:12:36] INFO (nni.retiarii.oneshot.pytorch.darts/MainThread) Epoch [2/2] Step [361/391] acc1 0.250000 (0.328558) loss 1.773805 (1.846033)\n", "[2021-06-07 11:12:36] INFO (nni.retiarii.oneshot.pytorch.darts/MainThread) Epoch [2/2] Step [371/391] acc1 0.312500 (0.329094) loss 1.901358 (1.844091)\n", "[2021-06-07 11:12:36] INFO (nni.retiarii.oneshot.pytorch.darts/MainThread) Epoch [2/2] Step [381/391] acc1 0.250000 (0.330011) loss 1.863921 (1.841390)\n", "[2021-06-07 11:12:36] INFO (nni.retiarii.oneshot.pytorch.darts/MainThread) Epoch [2/2] Step [391/391] acc1 0.325000 (0.331514) loss 1.729926 (1.837162)\n" ] } ], "source": [ "import torch\n", "from utils import accuracy\n", "from torchvision import transforms\n", "from torchvision.datasets import CIFAR10\n", "from nni.retiarii.oneshot.pytorch import DartsTrainer\n", "\n", "criterion = torch.nn.CrossEntropyLoss()\n", "optimizer = torch.optim.SGD(model.parameters(), lr=0.001, momentum=0.9)\n", "\n", "transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])\n", "train_dataset = CIFAR10(root=\"./data\", train=True, download=True, transform=transform)\n", "\n", "trainer = DartsTrainer(\n", " model=model,\n", " 
loss=criterion,\n", " metrics=lambda output, target: accuracy(output, target),\n", " optimizer=optimizer,\n", " num_epochs=2,\n", " dataset=train_dataset,\n", " batch_size=64,\n", " log_frequency=10\n", " )\n", "\n", "trainer.fit()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Once the search finishes, the best architecture found can be exported." ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Final architecture: {'_mutation_1': 1, '_mutation_2': 1, '_mutation_3': [1]}\n" ] } ], "source": [ "print('Final architecture:', trainer.export())" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.8.5" } }, "nbformat": 4, "nbformat_minor": 4 }