{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# 1. Prepare model" ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [], "source": [ "import torch\n", "import torch.nn.functional as F\n", "\n", "class NaiveModel(torch.nn.Module):\n", " def __init__(self):\n", " super().__init__()\n", " self.conv1 = torch.nn.Conv2d(1, 20, 5, 1)\n", " self.conv2 = torch.nn.Conv2d(20, 50, 5, 1)\n", " self.fc1 = torch.nn.Linear(4 * 4 * 50, 500)\n", " self.fc2 = torch.nn.Linear(500, 10)\n", " self.relu1 = torch.nn.ReLU6()\n", " self.relu2 = torch.nn.ReLU6()\n", " self.relu3 = torch.nn.ReLU6()\n", " self.max_pool1 = torch.nn.MaxPool2d(2, 2)\n", " self.max_pool2 = torch.nn.MaxPool2d(2, 2)\n", "\n", " def forward(self, x):\n", " x = self.relu1(self.conv1(x))\n", " x = self.max_pool1(x)\n", " x = self.relu2(self.conv2(x))\n", " x = self.max_pool2(x)\n", " x = x.view(-1, x.size()[1:].numel())\n", " x = self.relu3(self.fc1(x))\n", " x = self.fc2(x)\n", " return F.log_softmax(x, dim=1)" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [], "source": [ "# define model, optimizer, criterion, data_loader, trainer, evaluator.\n", "\n", "import torch.optim as optim\n", "from torchvision import datasets, transforms\n", "from torch.optim.lr_scheduler import StepLR\n", "\n", "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n", "\n", "model = NaiveModel().to(device)\n", "\n", "optimizer = optim.Adadelta(model.parameters(), lr=1)\n", "\n", "criterion = torch.nn.NLLLoss()\n", "\n", "transform=transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))])\n", "train_dataset = datasets.MNIST('./data', train=True, download=True, transform=transform)\n", "test_dataset = datasets.MNIST('./data', train=False, transform=transform)\n", "train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=64)\n", "test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=1000)\n", "\n", "def trainer(model, optimizer, criterion, epoch):\n", " model.train()\n", " for batch_idx, (data, target) in enumerate(train_loader):\n", " data, target = data.to(device), target.to(device)\n", " optimizer.zero_grad()\n", " output = model(data)\n", " loss = criterion(output, target)\n", " loss.backward()\n", " optimizer.step()\n", " if batch_idx % 100 == 0:\n", " print('Train Epoch: {} [{}/{} ({:.0f}%)]\\tLoss: {:.6f}'.format(\n", " epoch, batch_idx * len(data), len(train_loader.dataset),\n", " 100. 
* batch_idx / len(train_loader), loss.item()))\n", "\n", "def evaluator(model):\n", " model.eval()\n", " test_loss = 0\n", " correct = 0\n", " with torch.no_grad():\n", " for data, target in test_loader:\n", " data, target = data.to(device), target.to(device)\n", " output = model(data)\n", " test_loss += F.nll_loss(output, target, reduction='sum').item()\n", " pred = output.argmax(dim=1, keepdim=True)\n", " correct += pred.eq(target.view_as(pred)).sum().item()\n", "\n", " test_loss /= len(test_loader.dataset)\n", " acc = 100 * correct / len(test_loader.dataset)\n", "\n", " print('\\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\\n'.format(\n", " test_loss, correct, len(test_loader.dataset), acc))\n", "\n", " return acc" ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Train Epoch: 0 [0/60000 (0%)]\tLoss: 2.313423\n", "Train Epoch: 0 [6400/60000 (11%)]\tLoss: 0.091786\n", "Train Epoch: 0 [12800/60000 (21%)]\tLoss: 0.087317\n", "Train Epoch: 0 [19200/60000 (32%)]\tLoss: 0.036397\n", "Train Epoch: 0 [25600/60000 (43%)]\tLoss: 0.008173\n", "Train Epoch: 0 [32000/60000 (53%)]\tLoss: 0.047565\n", "Train Epoch: 0 [38400/60000 (64%)]\tLoss: 0.122448\n", "Train Epoch: 0 [44800/60000 (75%)]\tLoss: 0.036732\n", "Train Epoch: 0 [51200/60000 (85%)]\tLoss: 0.150135\n", "Train Epoch: 0 [57600/60000 (96%)]\tLoss: 0.109684\n", "\n", "Test set: Average loss: 0.0457, Accuracy: 9857/10000 (99%)\n", "\n", "Train Epoch: 1 [0/60000 (0%)]\tLoss: 0.020650\n", "Train Epoch: 1 [6400/60000 (11%)]\tLoss: 0.091525\n", "Train Epoch: 1 [12800/60000 (21%)]\tLoss: 0.019602\n", "Train Epoch: 1 [19200/60000 (32%)]\tLoss: 0.027827\n", "Train Epoch: 1 [25600/60000 (43%)]\tLoss: 0.019414\n", "Train Epoch: 1 [32000/60000 (53%)]\tLoss: 0.007640\n", "Train Epoch: 1 [38400/60000 (64%)]\tLoss: 0.051296\n", "Train Epoch: 1 [44800/60000 (75%)]\tLoss: 0.012038\n", "Train Epoch: 1 [51200/60000 (85%)]\tLoss: 0.121057\n", "Train Epoch: 1 [57600/60000 (96%)]\tLoss: 0.015796\n", "\n", "Test set: Average loss: 0.0302, Accuracy: 9902/10000 (99%)\n", "\n", "Train Epoch: 2 [0/60000 (0%)]\tLoss: 0.009903\n", "Train Epoch: 2 [6400/60000 (11%)]\tLoss: 0.062256\n", "Train Epoch: 2 [12800/60000 (21%)]\tLoss: 0.013844\n", "Train Epoch: 2 [19200/60000 (32%)]\tLoss: 0.014133\n", "Train Epoch: 2 [25600/60000 (43%)]\tLoss: 0.001051\n", "Train Epoch: 2 [32000/60000 (53%)]\tLoss: 0.006128\n", "Train Epoch: 2 [38400/60000 (64%)]\tLoss: 0.032162\n", "Train Epoch: 2 [44800/60000 (75%)]\tLoss: 0.007687\n", "Train Epoch: 2 [51200/60000 (85%)]\tLoss: 0.092295\n", "Train Epoch: 2 [57600/60000 (96%)]\tLoss: 0.006266\n", "\n", "Test set: Average loss: 0.0259, Accuracy: 9920/10000 (99%)\n", "\n" ] } ], "source": [ "# pre-train model for 3 epoches.\n", "\n", "scheduler = StepLR(optimizer, step_size=1, gamma=0.7)\n", "\n", "for epoch in range(0, 3):\n", " trainer(model, optimizer, criterion, epoch)\n", " evaluator(model)\n", " scheduler.step()" ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "op_name: \n", "op_type: \n", "\n", "op_name: conv1\n", "op_type: \n", "\n", "op_name: conv2\n", "op_type: \n", "\n", "op_name: fc1\n", "op_type: \n", "\n", "op_name: fc2\n", "op_type: \n", "\n", "op_name: relu1\n", "op_type: \n", "\n", "op_name: relu2\n", "op_type: \n", "\n", "op_name: relu3\n", "op_type: \n", "\n", "op_name: max_pool1\n", "op_type: \n", "\n", "op_name: max_pool2\n", "op_type: \n", 
"\n" ] }, { "data": { "text/plain": [ "[None, None, None, None, None, None, None, None, None, None]" ] }, "execution_count": 4, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# show all op_name and op_type in the model.\n", "\n", "[print('op_name: {}\\nop_type: {}\\n'.format(name, type(module))) for name, module in model.named_modules()]" ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "torch.Size([20, 1, 5, 5])\n" ] } ], "source": [ "# show the weight size of `conv1`.\n", "\n", "print(model.conv1.weight.data.size())" ] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "tensor([[[[ 1.5338e-01, -1.1766e-01, -2.6654e-01, -2.9445e-02, -1.4650e-01],\n", " [-1.8796e-01, -2.9882e-01, 6.9725e-02, 2.1561e-01, 6.5688e-02],\n", " [ 1.5274e-01, -9.8471e-03, 3.2303e-01, 1.3472e-03, 1.7235e-01],\n", " [ 1.1804e-01, 2.2535e-01, -8.3370e-02, -3.4553e-02, -1.2529e-01],\n", " [-6.6012e-02, -2.0272e-02, -1.8797e-01, -4.6882e-02, -8.3206e-02]]],\n", "\n", "\n", " [[[-1.2112e-01, 7.0756e-02, 5.0446e-02, 1.5156e-01, -2.7929e-02],\n", " [-1.9744e-01, -2.1336e-03, 7.2534e-02, 6.2336e-02, 1.6039e-01],\n", " [-6.7510e-02, 1.4636e-01, 7.1972e-02, -8.9118e-02, -4.0895e-02],\n", " [ 2.9499e-02, 2.0788e-01, -1.4989e-01, 1.1668e-01, -2.8503e-01],\n", " [ 8.1894e-02, -1.4489e-01, -4.2038e-02, -1.2794e-01, -5.0379e-02]]],\n", "\n", "\n", " [[[ 3.8332e-02, -1.4270e-01, -1.9585e-01, 2.2653e-01, 1.0104e-01],\n", " [-2.7956e-03, -1.4108e-01, -1.4694e-01, -1.3525e-01, 2.6959e-01],\n", " [ 1.9522e-01, -1.2281e-01, -1.9173e-01, -1.8910e-02, 3.1572e-03],\n", " [-1.0580e-01, -2.5239e-02, -5.8266e-02, -6.5815e-02, 6.6433e-02],\n", " [ 8.9601e-02, 7.1189e-02, -2.4255e-01, 1.5746e-01, -1.4708e-01]]],\n", "\n", "\n", " [[[-1.1963e-01, -1.7243e-01, -3.5174e-02, 1.4651e-01, -1.1675e-01],\n", " [-1.3518e-01, 1.2830e-02, 7.7188e-02, 2.1060e-01, 4.0924e-02],\n", " [-4.3364e-02, -1.9579e-01, -3.6559e-02, -6.9803e-02, 1.2380e-01],\n", " [ 7.7321e-02, 3.7590e-02, 8.2935e-02, 2.2878e-01, 2.7859e-03],\n", " [-1.3601e-01, -2.1167e-01, -2.3195e-01, -1.2524e-01, 1.0073e-01]]],\n", "\n", "\n", " [[[-2.7300e-01, 6.8470e-02, 2.8405e-02, -4.5879e-03, -1.3735e-01],\n", " [-8.9789e-02, -2.0209e-03, 5.0950e-03, 2.1633e-01, 2.5554e-01],\n", " [ 5.4389e-02, 1.2262e-01, -1.5514e-01, -1.0416e-01, 1.3606e-01],\n", " [-1.6794e-01, -2.8876e-02, 2.5900e-02, -2.4261e-02, 1.0923e-01],\n", " [ 5.2524e-03, -4.4625e-02, -2.1327e-01, -1.7211e-01, -4.4819e-04]]],\n", "\n", "\n", " [[[ 7.2378e-02, 1.5122e-01, -1.2964e-01, 4.9105e-02, -2.1639e-01],\n", " [ 3.6547e-02, -1.5518e-02, 3.2059e-02, -3.2820e-02, 6.1231e-02],\n", " [ 1.2514e-01, 8.0623e-02, 1.2686e-02, -1.0074e-01, 2.2836e-02],\n", " [-2.6842e-02, 2.5578e-02, -2.5877e-01, -1.7808e-01, 7.6966e-02],\n", " [-4.2424e-02, 4.7006e-02, -1.5486e-02, -4.2686e-02, 4.8482e-02]]],\n", "\n", "\n", " [[[ 1.3081e-01, 9.9530e-02, -1.4729e-01, -1.7665e-01, -1.9757e-01],\n", " [ 9.6603e-02, 2.2783e-02, 7.8402e-02, -2.8679e-02, 8.5252e-02],\n", " [-1.5310e-02, 1.1605e-01, -5.8300e-02, 2.4563e-02, 1.7488e-01],\n", " [ 6.5576e-02, -1.6325e-01, -1.1318e-01, -2.9251e-02, 6.2352e-02],\n", " [-1.9084e-03, -1.4005e-01, -1.2363e-01, -9.7985e-02, -2.0562e-01]]],\n", "\n", "\n", " [[[ 4.0772e-02, -8.2086e-02, -2.7555e-01, -3.2547e-01, -1.2226e-01],\n", " [-5.9877e-02, 9.8567e-02, 2.5186e-01, -1.0280e-01, -2.3416e-01],\n", " [ 8.5760e-02, 1.0896e-01, 
1.4898e-01, 2.1579e-01, 8.5297e-02],\n", " [ 5.4720e-02, -1.7226e-01, -7.2518e-02, 6.7099e-03, -1.6011e-03],\n", " [-8.9944e-02, 1.7404e-01, -3.6985e-02, 1.8602e-01, 7.2353e-02]]],\n", "\n", "\n", " [[[ 1.6276e-02, -9.6439e-02, -9.6085e-02, -2.4267e-01, -1.8521e-01],\n", " [ 6.3310e-02, 1.7866e-01, 1.1694e-01, -1.4464e-01, -2.7711e-01],\n", " [-2.4514e-02, 2.2222e-01, 2.1053e-01, -1.4271e-01, 8.7045e-02],\n", " [-1.9207e-01, -5.4719e-02, -5.7775e-03, -1.0034e-05, -1.0923e-01],\n", " [-2.4006e-02, 2.3780e-02, 1.8988e-01, 2.4734e-01, 4.8097e-02]]],\n", "\n", "\n", " [[[ 1.1335e-01, -5.8451e-02, 5.2440e-02, -1.3223e-01, -2.5534e-02],\n", " [ 9.1323e-02, -6.0707e-02, 2.3524e-01, 2.4992e-01, 8.7842e-02],\n", " [ 2.9002e-02, 3.5379e-02, -5.9689e-02, -2.8363e-03, 1.8618e-01],\n", " [-2.9671e-01, 8.1830e-03, 1.1076e-01, -5.4118e-02, -6.1685e-02],\n", " [-1.7580e-01, -3.4534e-01, -3.9250e-01, -2.7569e-01, -2.6131e-01]]],\n", "\n", "\n", " [[[ 1.1586e-01, -7.5997e-02, -1.4614e-01, 4.8750e-02, 1.8097e-01],\n", " [-6.7027e-02, -1.4901e-01, -1.5614e-02, -1.0379e-02, 9.5526e-02],\n", " [-3.2333e-02, -1.5107e-01, -1.9498e-01, 1.0083e-01, 2.2328e-01],\n", " [-2.0692e-01, -6.3798e-02, -1.2524e-01, 1.9549e-01, 1.9682e-01],\n", " [-2.1494e-01, 1.0475e-01, -2.4858e-02, -9.7831e-02, 1.1551e-01]]],\n", "\n", "\n", " [[[ 6.3785e-02, -1.8044e-01, -1.0190e-01, -1.3588e-01, 8.5433e-02],\n", " [ 2.0675e-01, 3.3238e-02, 9.2437e-02, 1.1799e-01, 2.1111e-01],\n", " [-5.2138e-02, 1.5790e-01, 1.8151e-01, 8.0470e-02, 1.0131e-01],\n", " [-4.4786e-02, 1.1771e-01, 2.1706e-02, -1.2563e-01, -2.1142e-01],\n", " [-2.3589e-01, -2.1154e-01, -1.7890e-01, -2.7769e-01, -1.2512e-01]]],\n", "\n", "\n", " [[[ 1.9133e-01, 2.4711e-01, 1.0413e-01, -1.9187e-01, -3.0991e-01],\n", " [-1.2382e-01, 8.3641e-03, -5.6734e-02, 5.8376e-02, 2.2880e-02],\n", " [-3.1734e-01, -1.0637e-02, -5.5974e-02, 1.0676e-01, -1.1080e-02],\n", " [-2.2980e-01, 2.0486e-01, 1.0147e-01, 1.4484e-01, 5.2265e-02],\n", " [ 7.4410e-02, 2.2806e-02, 8.5137e-02, -2.1809e-01, 3.1704e-02]]],\n", "\n", "\n", " [[[-1.1006e-01, -2.5311e-01, 1.8925e-02, 1.0399e-02, 1.1951e-01],\n", " [-2.1116e-01, 1.8409e-01, 3.2172e-02, 1.5962e-01, -7.9457e-02],\n", " [ 1.1059e-01, 9.1966e-02, 1.0777e-01, -9.9132e-02, -4.4586e-02],\n", " [-8.7919e-02, -3.7283e-02, 9.1275e-02, -3.7412e-02, 3.8875e-02],\n", " [-4.3558e-02, 1.6196e-01, -4.7944e-03, -1.7560e-02, -1.2593e-01]]],\n", "\n", "\n", " [[[ 7.6976e-02, -3.8627e-02, 1.2610e-01, 1.1994e-01, 2.1706e-03],\n", " [ 7.4357e-02, 6.7929e-02, 3.1386e-02, 1.4606e-01, 2.1429e-01],\n", " [-2.6569e-01, -4.2631e-04, -3.6654e-02, -3.0967e-02, -9.4961e-02],\n", " [-2.0192e-01, -3.5423e-01, -2.5246e-01, -3.5092e-01, -2.4159e-01],\n", " [ 1.7636e-02, 1.3744e-01, -1.0306e-01, 8.8370e-02, 7.3258e-02]]],\n", "\n", "\n", " [[[ 2.0016e-01, 1.0956e-01, -5.9223e-02, 6.4871e-03, -2.4165e-01],\n", " [ 5.6283e-02, 1.7276e-01, -2.2316e-01, -1.6699e-01, -7.0742e-02],\n", " [ 2.6179e-01, -2.5102e-01, -2.0774e-01, -9.6413e-02, 3.4367e-02],\n", " [-9.1882e-02, -2.9195e-01, -8.7432e-02, 1.0144e-01, -2.0559e-02],\n", " [-2.5668e-01, -9.8016e-02, 1.1103e-01, -3.0233e-02, 1.1076e-01]]],\n", "\n", "\n", " [[[ 1.0027e-03, -5.7955e-02, -2.1339e-01, -1.6729e-01, -2.0870e-01],\n", " [ 4.2464e-02, 2.3177e-01, -6.1459e-02, -1.0905e-01, 1.7613e-02],\n", " [-1.2282e-01, 2.1762e-01, -1.3553e-02, 2.7476e-01, 1.6703e-01],\n", " [-5.6282e-02, 1.2731e-02, 1.0944e-01, -1.7347e-01, 4.4497e-02],\n", " [ 5.7346e-02, -5.4657e-02, 4.8718e-02, -2.6221e-02, -2.6933e-02]]],\n", "\n", "\n", " [[[ 
6.7697e-02, 1.5692e-01, 2.7050e-01, 1.5936e-02, 1.7659e-01],\n", " [-2.8899e-02, -1.4866e-01, 3.1838e-02, 1.0903e-01, 1.2292e-01],\n", " [-1.3608e-01, -4.3198e-03, -9.8925e-02, -4.5599e-02, 1.3452e-01],\n", " [-5.1435e-02, -2.3815e-01, -2.4151e-01, -4.8556e-02, 1.3825e-01],\n", " [-1.2823e-01, 8.9324e-03, -1.5313e-01, -2.2933e-01, -3.4081e-02]]],\n", "\n", "\n", " [[[-1.8396e-01, -6.8774e-03, -1.6675e-01, 7.1980e-03, 1.9922e-02],\n", " [ 1.3416e-01, -1.1450e-01, -1.5277e-01, -6.5713e-02, -9.5435e-02],\n", " [ 1.5406e-01, -9.1235e-02, -1.0880e-01, -7.1603e-02, -9.5575e-02],\n", " [ 2.1772e-01, 8.4073e-02, -2.5264e-01, -2.1428e-01, 1.9537e-01],\n", " [ 1.3124e-01, 7.9532e-02, -2.4044e-01, -1.5717e-01, 1.6562e-01]]],\n", "\n", "\n", " [[[ 1.1849e-01, -5.0517e-03, -1.8900e-01, 1.8093e-02, 6.4660e-02],\n", " [-1.5309e-01, -2.0106e-01, -8.6551e-02, 5.2692e-03, 1.5448e-01],\n", " [-3.0727e-01, 4.9703e-02, -4.7637e-02, 2.9111e-01, -1.3173e-01],\n", " [-8.5167e-02, -1.3540e-01, 2.9235e-01, 3.7895e-03, -9.4651e-02],\n", " [-6.0694e-02, 9.6936e-02, 1.0533e-01, -6.1769e-02, -1.8086e-01]]]],\n", " device='cuda:0')\n" ] } ], "source": [ "# show the weight of `conv1`.\n", "\n", "print(model.conv1.weight.data)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# 2. Prepare config_list for pruning" ] }, { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [], "source": [ "# we will prune 50% of the filters (output channels) of `conv1`.\n", "\n", "config_list = [{\n", " 'sparsity': 0.5,\n", " 'op_types': ['Conv2d'],\n", " 'op_names': ['conv1']\n", "}]" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# 3. Choose a pruner and pruning" ] }, { "cell_type": "code", "execution_count": 8, "metadata": {}, "outputs": [], "source": [ "# use L1FilterPruner to prune the model.\n", "\n", "from nni.algorithms.compression.pytorch.pruning import L1FilterPruner\n", "\n", "# Note: if you use a compressor that requires an optimizer, pass a newly created optimizer\n", "# rather than the one used above, because NNI may modify the optimizer it receives.\n", "# The modified optimizer also cannot be reused for finetuning.\n", "pruner = L1FilterPruner(model, config_list)" ] }, { "cell_type": "code", "execution_count": 9, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "op_name: \n", "op_type: \n", "\n", "op_name: conv1\n", "op_type: \n", "\n", "op_name: conv1.module\n", "op_type: \n", "\n", "op_name: conv2\n", "op_type: \n", "\n", "op_name: fc1\n", "op_type: \n", "\n", "op_name: fc2\n", "op_type: \n", "\n", "op_name: relu1\n", "op_type: \n", "\n", "op_name: relu2\n", "op_type: \n", "\n", "op_name: relu3\n", "op_type: \n", "\n", "op_name: max_pool1\n", "op_type: \n", "\n", "op_name: max_pool2\n", "op_type: \n", "\n" ] }, { "data": { "text/plain": [ "[None, None, None, None, None, None, None, None, None, None, None]" ] }, "execution_count": 9, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# we can see that `conv1` has been wrapped; the original `conv1` is now `conv1.module`.\n", "# the weight of `conv1` is masked via `weight * mask` in `forward()`. 
The initial mask is a `ones_like(weight)` tensor.\n", "\n", "[print('op_name: {}\\nop_type: {}\\n'.format(name, type(module))) for name, module in model.named_modules()]" ] }, { "cell_type": "code", "execution_count": 10, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "NaiveModel(\n", " (conv1): PrunerModuleWrapper(\n", " (module): Conv2d(1, 20, kernel_size=(5, 5), stride=(1, 1))\n", " )\n", " (conv2): Conv2d(20, 50, kernel_size=(5, 5), stride=(1, 1))\n", " (fc1): Linear(in_features=800, out_features=500, bias=True)\n", " (fc2): Linear(in_features=500, out_features=10, bias=True)\n", " (relu1): ReLU6()\n", " (relu2): ReLU6()\n", " (relu3): ReLU6()\n", " (max_pool1): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)\n", " (max_pool2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)\n", ")" ] }, "execution_count": 10, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# compress the model, the mask will be updated.\n", "\n", "pruner.compress()" ] }, { "cell_type": "code", "execution_count": 11, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "torch.Size([20, 1, 5, 5])\n" ] } ], "source": [ "# show the mask size of `conv1`\n", "\n", "print(model.conv1.weight_mask.size())" ] }, { "cell_type": "code", "execution_count": 12, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "tensor([[[[1., 1., 1., 1., 1.],\n", " [1., 1., 1., 1., 1.],\n", " [1., 1., 1., 1., 1.],\n", " [1., 1., 1., 1., 1.],\n", " [1., 1., 1., 1., 1.]]],\n", "\n", "\n", " [[[0., 0., 0., 0., 0.],\n", " [0., 0., 0., 0., 0.],\n", " [0., 0., 0., 0., 0.],\n", " [0., 0., 0., 0., 0.],\n", " [0., 0., 0., 0., 0.]]],\n", "\n", "\n", " [[[1., 1., 1., 1., 1.],\n", " [1., 1., 1., 1., 1.],\n", " [1., 1., 1., 1., 1.],\n", " [1., 1., 1., 1., 1.],\n", " [1., 1., 1., 1., 1.]]],\n", "\n", "\n", " [[[0., 0., 0., 0., 0.],\n", " [0., 0., 0., 0., 0.],\n", " [0., 0., 0., 0., 0.],\n", " [0., 0., 0., 0., 0.],\n", " [0., 0., 0., 0., 0.]]],\n", "\n", "\n", " [[[0., 0., 0., 0., 0.],\n", " [0., 0., 0., 0., 0.],\n", " [0., 0., 0., 0., 0.],\n", " [0., 0., 0., 0., 0.],\n", " [0., 0., 0., 0., 0.]]],\n", "\n", "\n", " [[[0., 0., 0., 0., 0.],\n", " [0., 0., 0., 0., 0.],\n", " [0., 0., 0., 0., 0.],\n", " [0., 0., 0., 0., 0.],\n", " [0., 0., 0., 0., 0.]]],\n", "\n", "\n", " [[[0., 0., 0., 0., 0.],\n", " [0., 0., 0., 0., 0.],\n", " [0., 0., 0., 0., 0.],\n", " [0., 0., 0., 0., 0.],\n", " [0., 0., 0., 0., 0.]]],\n", "\n", "\n", " [[[1., 1., 1., 1., 1.],\n", " [1., 1., 1., 1., 1.],\n", " [1., 1., 1., 1., 1.],\n", " [1., 1., 1., 1., 1.],\n", " [1., 1., 1., 1., 1.]]],\n", "\n", "\n", " [[[1., 1., 1., 1., 1.],\n", " [1., 1., 1., 1., 1.],\n", " [1., 1., 1., 1., 1.],\n", " [1., 1., 1., 1., 1.],\n", " [1., 1., 1., 1., 1.]]],\n", "\n", "\n", " [[[1., 1., 1., 1., 1.],\n", " [1., 1., 1., 1., 1.],\n", " [1., 1., 1., 1., 1.],\n", " [1., 1., 1., 1., 1.],\n", " [1., 1., 1., 1., 1.]]],\n", "\n", "\n", " [[[0., 0., 0., 0., 0.],\n", " [0., 0., 0., 0., 0.],\n", " [0., 0., 0., 0., 0.],\n", " [0., 0., 0., 0., 0.],\n", " [0., 0., 0., 0., 0.]]],\n", "\n", "\n", " [[[1., 1., 1., 1., 1.],\n", " [1., 1., 1., 1., 1.],\n", " [1., 1., 1., 1., 1.],\n", " [1., 1., 1., 1., 1.],\n", " [1., 1., 1., 1., 1.]]],\n", "\n", "\n", " [[[1., 1., 1., 1., 1.],\n", " [1., 1., 1., 1., 1.],\n", " [1., 1., 1., 1., 1.],\n", " [1., 1., 1., 1., 1.],\n", " [1., 1., 1., 1., 1.]]],\n", "\n", "\n", " [[[0., 0., 0., 0., 0.],\n", " [0., 0., 0., 0., 0.],\n", " [0., 0., 0., 0., 
0.],\n", " [0., 0., 0., 0., 0.],\n", " [0., 0., 0., 0., 0.]]],\n", "\n", "\n", " [[[1., 1., 1., 1., 1.],\n", " [1., 1., 1., 1., 1.],\n", " [1., 1., 1., 1., 1.],\n", " [1., 1., 1., 1., 1.],\n", " [1., 1., 1., 1., 1.]]],\n", "\n", "\n", " [[[1., 1., 1., 1., 1.],\n", " [1., 1., 1., 1., 1.],\n", " [1., 1., 1., 1., 1.],\n", " [1., 1., 1., 1., 1.],\n", " [1., 1., 1., 1., 1.]]],\n", "\n", "\n", " [[[0., 0., 0., 0., 0.],\n", " [0., 0., 0., 0., 0.],\n", " [0., 0., 0., 0., 0.],\n", " [0., 0., 0., 0., 0.],\n", " [0., 0., 0., 0., 0.]]],\n", "\n", "\n", " [[[0., 0., 0., 0., 0.],\n", " [0., 0., 0., 0., 0.],\n", " [0., 0., 0., 0., 0.],\n", " [0., 0., 0., 0., 0.],\n", " [0., 0., 0., 0., 0.]]],\n", "\n", "\n", " [[[1., 1., 1., 1., 1.],\n", " [1., 1., 1., 1., 1.],\n", " [1., 1., 1., 1., 1.],\n", " [1., 1., 1., 1., 1.],\n", " [1., 1., 1., 1., 1.]]],\n", "\n", "\n", " [[[0., 0., 0., 0., 0.],\n", " [0., 0., 0., 0., 0.],\n", " [0., 0., 0., 0., 0.],\n", " [0., 0., 0., 0., 0.],\n", " [0., 0., 0., 0., 0.]]]], device='cuda:0')\n" ] } ], "source": [ "# show the mask of `conv1`\n", "\n", "print(model.conv1.weight_mask)" ] }, { "cell_type": "code", "execution_count": 13, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "tensor([[[[ 1.5338e-01, -1.1766e-01, -2.6654e-01, -2.9445e-02, -1.4650e-01],\n", " [-1.8796e-01, -2.9882e-01, 6.9725e-02, 2.1561e-01, 6.5688e-02],\n", " [ 1.5274e-01, -9.8471e-03, 3.2303e-01, 1.3472e-03, 1.7235e-01],\n", " [ 1.1804e-01, 2.2535e-01, -8.3370e-02, -3.4553e-02, -1.2529e-01],\n", " [-6.6012e-02, -2.0272e-02, -1.8797e-01, -4.6882e-02, -8.3206e-02]]],\n", "\n", "\n", " [[[-0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, -0.0000e+00],\n", " [-0.0000e+00, -0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00],\n", " [-0.0000e+00, 0.0000e+00, 0.0000e+00, -0.0000e+00, -0.0000e+00],\n", " [ 0.0000e+00, 0.0000e+00, -0.0000e+00, 0.0000e+00, -0.0000e+00],\n", " [ 0.0000e+00, -0.0000e+00, -0.0000e+00, -0.0000e+00, -0.0000e+00]]],\n", "\n", "\n", " [[[ 3.8332e-02, -1.4270e-01, -1.9585e-01, 2.2653e-01, 1.0104e-01],\n", " [-2.7956e-03, -1.4108e-01, -1.4694e-01, -1.3525e-01, 2.6959e-01],\n", " [ 1.9522e-01, -1.2281e-01, -1.9173e-01, -1.8910e-02, 3.1572e-03],\n", " [-1.0580e-01, -2.5239e-02, -5.8266e-02, -6.5815e-02, 6.6433e-02],\n", " [ 8.9601e-02, 7.1189e-02, -2.4255e-01, 1.5746e-01, -1.4708e-01]]],\n", "\n", "\n", " [[[-0.0000e+00, -0.0000e+00, -0.0000e+00, 0.0000e+00, -0.0000e+00],\n", " [-0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00],\n", " [-0.0000e+00, -0.0000e+00, -0.0000e+00, -0.0000e+00, 0.0000e+00],\n", " [ 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00],\n", " [-0.0000e+00, -0.0000e+00, -0.0000e+00, -0.0000e+00, 0.0000e+00]]],\n", "\n", "\n", " [[[-0.0000e+00, 0.0000e+00, 0.0000e+00, -0.0000e+00, -0.0000e+00],\n", " [-0.0000e+00, -0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00],\n", " [ 0.0000e+00, 0.0000e+00, -0.0000e+00, -0.0000e+00, 0.0000e+00],\n", " [-0.0000e+00, -0.0000e+00, 0.0000e+00, -0.0000e+00, 0.0000e+00],\n", " [ 0.0000e+00, -0.0000e+00, -0.0000e+00, -0.0000e+00, -0.0000e+00]]],\n", "\n", "\n", " [[[ 0.0000e+00, 0.0000e+00, -0.0000e+00, 0.0000e+00, -0.0000e+00],\n", " [ 0.0000e+00, -0.0000e+00, 0.0000e+00, -0.0000e+00, 0.0000e+00],\n", " [ 0.0000e+00, 0.0000e+00, 0.0000e+00, -0.0000e+00, 0.0000e+00],\n", " [-0.0000e+00, 0.0000e+00, -0.0000e+00, -0.0000e+00, 0.0000e+00],\n", " [-0.0000e+00, 0.0000e+00, -0.0000e+00, -0.0000e+00, 0.0000e+00]]],\n", "\n", "\n", " [[[ 0.0000e+00, 0.0000e+00, -0.0000e+00, -0.0000e+00, 
-0.0000e+00],\n", " [ 0.0000e+00, 0.0000e+00, 0.0000e+00, -0.0000e+00, 0.0000e+00],\n", " [-0.0000e+00, 0.0000e+00, -0.0000e+00, 0.0000e+00, 0.0000e+00],\n", " [ 0.0000e+00, -0.0000e+00, -0.0000e+00, -0.0000e+00, 0.0000e+00],\n", " [-0.0000e+00, -0.0000e+00, -0.0000e+00, -0.0000e+00, -0.0000e+00]]],\n", "\n", "\n", " [[[ 4.0772e-02, -8.2086e-02, -2.7555e-01, -3.2547e-01, -1.2226e-01],\n", " [-5.9877e-02, 9.8567e-02, 2.5186e-01, -1.0280e-01, -2.3416e-01],\n", " [ 8.5760e-02, 1.0896e-01, 1.4898e-01, 2.1579e-01, 8.5297e-02],\n", " [ 5.4720e-02, -1.7226e-01, -7.2518e-02, 6.7099e-03, -1.6011e-03],\n", " [-8.9944e-02, 1.7404e-01, -3.6985e-02, 1.8602e-01, 7.2353e-02]]],\n", "\n", "\n", " [[[ 1.6276e-02, -9.6439e-02, -9.6085e-02, -2.4267e-01, -1.8521e-01],\n", " [ 6.3310e-02, 1.7866e-01, 1.1694e-01, -1.4464e-01, -2.7711e-01],\n", " [-2.4514e-02, 2.2222e-01, 2.1053e-01, -1.4271e-01, 8.7045e-02],\n", " [-1.9207e-01, -5.4719e-02, -5.7775e-03, -1.0034e-05, -1.0923e-01],\n", " [-2.4006e-02, 2.3780e-02, 1.8988e-01, 2.4734e-01, 4.8097e-02]]],\n", "\n", "\n", " [[[ 1.1335e-01, -5.8451e-02, 5.2440e-02, -1.3223e-01, -2.5534e-02],\n", " [ 9.1323e-02, -6.0707e-02, 2.3524e-01, 2.4992e-01, 8.7842e-02],\n", " [ 2.9002e-02, 3.5379e-02, -5.9689e-02, -2.8363e-03, 1.8618e-01],\n", " [-2.9671e-01, 8.1830e-03, 1.1076e-01, -5.4118e-02, -6.1685e-02],\n", " [-1.7580e-01, -3.4534e-01, -3.9250e-01, -2.7569e-01, -2.6131e-01]]],\n", "\n", "\n", " [[[ 0.0000e+00, -0.0000e+00, -0.0000e+00, 0.0000e+00, 0.0000e+00],\n", " [-0.0000e+00, -0.0000e+00, -0.0000e+00, -0.0000e+00, 0.0000e+00],\n", " [-0.0000e+00, -0.0000e+00, -0.0000e+00, 0.0000e+00, 0.0000e+00],\n", " [-0.0000e+00, -0.0000e+00, -0.0000e+00, 0.0000e+00, 0.0000e+00],\n", " [-0.0000e+00, 0.0000e+00, -0.0000e+00, -0.0000e+00, 0.0000e+00]]],\n", "\n", "\n", " [[[ 6.3785e-02, -1.8044e-01, -1.0190e-01, -1.3588e-01, 8.5433e-02],\n", " [ 2.0675e-01, 3.3238e-02, 9.2437e-02, 1.1799e-01, 2.1111e-01],\n", " [-5.2138e-02, 1.5790e-01, 1.8151e-01, 8.0470e-02, 1.0131e-01],\n", " [-4.4786e-02, 1.1771e-01, 2.1706e-02, -1.2563e-01, -2.1142e-01],\n", " [-2.3589e-01, -2.1154e-01, -1.7890e-01, -2.7769e-01, -1.2512e-01]]],\n", "\n", "\n", " [[[ 1.9133e-01, 2.4711e-01, 1.0413e-01, -1.9187e-01, -3.0991e-01],\n", " [-1.2382e-01, 8.3641e-03, -5.6734e-02, 5.8376e-02, 2.2880e-02],\n", " [-3.1734e-01, -1.0637e-02, -5.5974e-02, 1.0676e-01, -1.1080e-02],\n", " [-2.2980e-01, 2.0486e-01, 1.0147e-01, 1.4484e-01, 5.2265e-02],\n", " [ 7.4410e-02, 2.2806e-02, 8.5137e-02, -2.1809e-01, 3.1704e-02]]],\n", "\n", "\n", " [[[-0.0000e+00, -0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00],\n", " [-0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, -0.0000e+00],\n", " [ 0.0000e+00, 0.0000e+00, 0.0000e+00, -0.0000e+00, -0.0000e+00],\n", " [-0.0000e+00, -0.0000e+00, 0.0000e+00, -0.0000e+00, 0.0000e+00],\n", " [-0.0000e+00, 0.0000e+00, -0.0000e+00, -0.0000e+00, -0.0000e+00]]],\n", "\n", "\n", " [[[ 7.6976e-02, -3.8627e-02, 1.2610e-01, 1.1994e-01, 2.1706e-03],\n", " [ 7.4357e-02, 6.7929e-02, 3.1386e-02, 1.4606e-01, 2.1429e-01],\n", " [-2.6569e-01, -4.2631e-04, -3.6654e-02, -3.0967e-02, -9.4961e-02],\n", " [-2.0192e-01, -3.5423e-01, -2.5246e-01, -3.5092e-01, -2.4159e-01],\n", " [ 1.7636e-02, 1.3744e-01, -1.0306e-01, 8.8370e-02, 7.3258e-02]]],\n", "\n", "\n", " [[[ 2.0016e-01, 1.0956e-01, -5.9223e-02, 6.4871e-03, -2.4165e-01],\n", " [ 5.6283e-02, 1.7276e-01, -2.2316e-01, -1.6699e-01, -7.0742e-02],\n", " [ 2.6179e-01, -2.5102e-01, -2.0774e-01, -9.6413e-02, 3.4367e-02],\n", " [-9.1882e-02, -2.9195e-01, -8.7432e-02, 1.0144e-01, 
-2.0559e-02],\n", " [-2.5668e-01, -9.8016e-02, 1.1103e-01, -3.0233e-02, 1.1076e-01]]],\n", "\n", "\n", " [[[ 0.0000e+00, -0.0000e+00, -0.0000e+00, -0.0000e+00, -0.0000e+00],\n", " [ 0.0000e+00, 0.0000e+00, -0.0000e+00, -0.0000e+00, 0.0000e+00],\n", " [-0.0000e+00, 0.0000e+00, -0.0000e+00, 0.0000e+00, 0.0000e+00],\n", " [-0.0000e+00, 0.0000e+00, 0.0000e+00, -0.0000e+00, 0.0000e+00],\n", " [ 0.0000e+00, -0.0000e+00, 0.0000e+00, -0.0000e+00, -0.0000e+00]]],\n", "\n", "\n", " [[[ 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00],\n", " [-0.0000e+00, -0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00],\n", " [-0.0000e+00, -0.0000e+00, -0.0000e+00, -0.0000e+00, 0.0000e+00],\n", " [-0.0000e+00, -0.0000e+00, -0.0000e+00, -0.0000e+00, 0.0000e+00],\n", " [-0.0000e+00, 0.0000e+00, -0.0000e+00, -0.0000e+00, -0.0000e+00]]],\n", "\n", "\n", " [[[-1.8396e-01, -6.8774e-03, -1.6675e-01, 7.1980e-03, 1.9922e-02],\n", " [ 1.3416e-01, -1.1450e-01, -1.5277e-01, -6.5713e-02, -9.5435e-02],\n", " [ 1.5406e-01, -9.1235e-02, -1.0880e-01, -7.1603e-02, -9.5575e-02],\n", " [ 2.1772e-01, 8.4073e-02, -2.5264e-01, -2.1428e-01, 1.9537e-01],\n", " [ 1.3124e-01, 7.9532e-02, -2.4044e-01, -1.5717e-01, 1.6562e-01]]],\n", "\n", "\n", " [[[ 0.0000e+00, -0.0000e+00, -0.0000e+00, 0.0000e+00, 0.0000e+00],\n", " [-0.0000e+00, -0.0000e+00, -0.0000e+00, 0.0000e+00, 0.0000e+00],\n", " [-0.0000e+00, 0.0000e+00, -0.0000e+00, 0.0000e+00, -0.0000e+00],\n", " [-0.0000e+00, -0.0000e+00, 0.0000e+00, 0.0000e+00, -0.0000e+00],\n", " [-0.0000e+00, 0.0000e+00, 0.0000e+00, -0.0000e+00, -0.0000e+00]]]],\n", " device='cuda:0')\n" ] } ], "source": [ "# use a dummy input to apply the sparsify.\n", "\n", "model(torch.rand(1, 1, 28, 28).to(device))\n", "\n", "# the weights of `conv1` have been sparsified.\n", "\n", "print(model.conv1.module.weight.data)" ] }, { "cell_type": "code", "execution_count": 14, "metadata": { "scrolled": true }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "[2021-07-26 22:26:05] INFO (nni.compression.pytorch.compressor/MainThread) Model state_dict saved to pruned_naive_mnist_l1filter.pth\n", "[2021-07-26 22:26:05] INFO (nni.compression.pytorch.compressor/MainThread) Mask dict saved to mask_naive_mnist_l1filter.pth\n" ] } ], "source": [ "# export the sparsified model state to './pruned_naive_mnist_l1filter.pth'.\n", "# export the mask to './mask_naive_mnist_l1filter.pth'.\n", "\n", "pruner.export_model(model_path='pruned_naive_mnist_l1filter.pth', mask_path='mask_naive_mnist_l1filter.pth')" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# 4. 
Speed Up" ] }, { "cell_type": "code", "execution_count": 15, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "NaiveModel(\n", " (conv1): Conv2d(1, 20, kernel_size=(5, 5), stride=(1, 1))\n", " (conv2): Conv2d(20, 50, kernel_size=(5, 5), stride=(1, 1))\n", " (fc1): Linear(in_features=800, out_features=500, bias=True)\n", " (fc2): Linear(in_features=500, out_features=10, bias=True)\n", " (relu1): ReLU6()\n", " (relu2): ReLU6()\n", " (relu3): ReLU6()\n", " (max_pool1): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)\n", " (max_pool2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)\n", ")\n" ] } ], "source": [ "# If you use a wrapped model, don't forget to unwrap it.\n", "\n", "pruner._unwrap_model()\n", "\n", "# the model has been unwrapped.\n", "\n", "print(model)" ] }, { "cell_type": "code", "execution_count": 16, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ ":22: TracerWarning: Converting a tensor to a Python index might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!\n", " x = x.view(-1, x.size()[1:].numel())\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "[2021-07-26 22:26:18] INFO (nni.compression.pytorch.speedup.compressor/MainThread) start to speed up the model\n", "[2021-07-26 22:26:18] INFO (FixMaskConflict/MainThread) {'conv1': 1, 'conv2': 1}\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "[2021-07-26 22:26:18] INFO (FixMaskConflict/MainThread) dim0 sparsity: 0.500000\n", "[2021-07-26 22:26:18] INFO (FixMaskConflict/MainThread) dim1 sparsity: 0.000000\n", "[2021-07-26 22:26:18] INFO (FixMaskConflict/MainThread) Dectected conv prune dim\" 0\n", "[2021-07-26 22:26:18] INFO (nni.compression.pytorch.speedup.compressor/MainThread) infer module masks...\n", "[2021-07-26 22:26:18] INFO (nni.compression.pytorch.speedup.compressor/MainThread) Update mask for conv1\n", "[2021-07-26 22:26:18] INFO (nni.compression.pytorch.speedup.compressor/MainThread) Update mask for relu1\n", "[2021-07-26 22:26:18] INFO (nni.compression.pytorch.speedup.compressor/MainThread) Update mask for max_pool1\n", "[2021-07-26 22:26:18] INFO (nni.compression.pytorch.speedup.compressor/MainThread) Update mask for conv2\n", "[2021-07-26 22:26:18] INFO (nni.compression.pytorch.speedup.compressor/MainThread) Update mask for relu2\n", "[2021-07-26 22:26:18] INFO (nni.compression.pytorch.speedup.compressor/MainThread) Update mask for max_pool2\n", "[2021-07-26 22:26:18] INFO (nni.compression.pytorch.speedup.compressor/MainThread) Update mask for .aten::view.9\n", "[2021-07-26 22:26:18] INFO (nni.compression.pytorch.speedup.jit_translate/MainThread) View Module output size: [-1, 800]\n", "[2021-07-26 22:26:18] INFO (nni.compression.pytorch.speedup.compressor/MainThread) Update mask for fc1\n", "[2021-07-26 22:26:18] INFO (nni.compression.pytorch.speedup.compressor/MainThread) Update mask for relu3\n", "[2021-07-26 22:26:18] INFO (nni.compression.pytorch.speedup.compressor/MainThread) Update mask for fc2\n", "[2021-07-26 22:26:18] INFO (nni.compression.pytorch.speedup.compressor/MainThread) Update mask for .aten::log_softmax.10\n", "[2021-07-26 22:26:18] ERROR (nni.compression.pytorch.speedup.jit_translate/MainThread) aten::log_softmax is not Supported! Please report an issue at https://github.com/microsoft/nni. 
Thanks~\n", "[2021-07-26 22:26:18] INFO (nni.compression.pytorch.speedup.compressor/MainThread) Update indirect sparsity for .aten::log_softmax.10\n", "[2021-07-26 22:26:18] WARNING (nni.compression.pytorch.speedup.compressor/MainThread) Note: .aten::log_softmax.10 does not have corresponding mask inference object\n", "[2021-07-26 22:26:18] INFO (nni.compression.pytorch.speedup.compressor/MainThread) Update indirect sparsity for fc2\n", "[2021-07-26 22:26:18] INFO (nni.compression.pytorch.speedup.compressor/MainThread) Update the indirect sparsity for the fc2\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "[2021-07-26 22:26:18] INFO (nni.compression.pytorch.speedup.compressor/MainThread) Update indirect sparsity for relu3\n", "[2021-07-26 22:26:18] INFO (nni.compression.pytorch.speedup.compressor/MainThread) Update the indirect sparsity for the relu3\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "[2021-07-26 22:26:18] INFO (nni.compression.pytorch.speedup.compressor/MainThread) Update indirect sparsity for fc1\n", "[2021-07-26 22:26:18] INFO (nni.compression.pytorch.speedup.compressor/MainThread) Update the indirect sparsity for the fc1\n", "[2021-07-26 22:26:18] INFO (nni.compression.pytorch.speedup.compressor/MainThread) Update indirect sparsity for .aten::view.9\n", "[2021-07-26 22:26:18] INFO (nni.compression.pytorch.speedup.compressor/MainThread) Update the indirect sparsity for the .aten::view.9\n", "[2021-07-26 22:26:18] INFO (nni.compression.pytorch.speedup.compressor/MainThread) Update indirect sparsity for max_pool2\n", "[2021-07-26 22:26:18] INFO (nni.compression.pytorch.speedup.compressor/MainThread) Update the indirect sparsity for the max_pool2\n", "[2021-07-26 22:26:18] INFO (nni.compression.pytorch.speedup.compressor/MainThread) Update indirect sparsity for relu2\n", "[2021-07-26 22:26:18] INFO (nni.compression.pytorch.speedup.compressor/MainThread) Update the indirect sparsity for the relu2\n", "[2021-07-26 22:26:18] INFO (nni.compression.pytorch.speedup.compressor/MainThread) Update indirect sparsity for conv2\n", "[2021-07-26 22:26:18] INFO (nni.compression.pytorch.speedup.compressor/MainThread) Update the indirect sparsity for the conv2\n", "[2021-07-26 22:26:18] INFO (nni.compression.pytorch.speedup.compressor/MainThread) Update indirect sparsity for max_pool1\n", "[2021-07-26 22:26:18] INFO (nni.compression.pytorch.speedup.compressor/MainThread) Update the indirect sparsity for the max_pool1\n", "[2021-07-26 22:26:18] INFO (nni.compression.pytorch.speedup.compressor/MainThread) Update indirect sparsity for relu1\n", "[2021-07-26 22:26:18] INFO (nni.compression.pytorch.speedup.compressor/MainThread) Update the indirect sparsity for the relu1\n", "[2021-07-26 22:26:18] INFO (nni.compression.pytorch.speedup.compressor/MainThread) Update indirect sparsity for conv1\n", "[2021-07-26 22:26:18] INFO (nni.compression.pytorch.speedup.compressor/MainThread) Update the indirect sparsity for the conv1\n", "[2021-07-26 22:26:18] INFO (nni.compression.pytorch.speedup.compressor/MainThread) resolve the mask conflict\n", "[2021-07-26 22:26:18] INFO (nni.compression.pytorch.speedup.compressor/MainThread) replace compressed modules...\n", "[2021-07-26 22:26:18] INFO (nni.compression.pytorch.speedup.compressor/MainThread) replace module (name: conv1, op_type: Conv2d)\n", "[2021-07-26 22:26:18] INFO (nni.compression.pytorch.speedup.compressor/MainThread) replace module (name: relu1, op_type: ReLU6)\n", "[2021-07-26 22:26:18] INFO 
(nni.compression.pytorch.speedup.compressor/MainThread) replace module (name: max_pool1, op_type: MaxPool2d)\n", "[2021-07-26 22:26:18] INFO (nni.compression.pytorch.speedup.compressor/MainThread) replace module (name: conv2, op_type: Conv2d)\n", "[2021-07-26 22:26:18] INFO (nni.compression.pytorch.speedup.compressor/MainThread) replace module (name: relu2, op_type: ReLU6)\n", "[2021-07-26 22:26:18] INFO (nni.compression.pytorch.speedup.compressor/MainThread) replace module (name: max_pool2, op_type: MaxPool2d)\n", "[2021-07-26 22:26:18] INFO (nni.compression.pytorch.speedup.compressor/MainThread) Warning: cannot replace (name: .aten::view.9, op_type: aten::view) which is func type\n", "[2021-07-26 22:26:18] INFO (nni.compression.pytorch.speedup.compressor/MainThread) replace module (name: fc1, op_type: Linear)\n", "[2021-07-26 22:26:18] INFO (nni.compression.pytorch.speedup.compress_modules/MainThread) replace linear with new in_features: 800, out_features: 500\n", "[2021-07-26 22:26:18] INFO (nni.compression.pytorch.speedup.compressor/MainThread) replace module (name: relu3, op_type: ReLU6)\n", "[2021-07-26 22:26:18] INFO (nni.compression.pytorch.speedup.compressor/MainThread) replace module (name: fc2, op_type: Linear)\n", "[2021-07-26 22:26:18] INFO (nni.compression.pytorch.speedup.compress_modules/MainThread) replace linear with new in_features: 500, out_features: 10\n", "[2021-07-26 22:26:18] INFO (nni.compression.pytorch.speedup.compressor/MainThread) Warning: cannot replace (name: .aten::log_softmax.10, op_type: aten::log_softmax) which is func type\n", "[2021-07-26 22:26:18] INFO (nni.compression.pytorch.speedup.compressor/MainThread) speedup done\n" ] } ], "source": [ "from nni.compression.pytorch import ModelSpeedup\n", "\n", "m_speedup = ModelSpeedup(model, dummy_input=torch.rand(10, 1, 28, 28).to(device), masks_file='mask_naive_mnist_l1filter.pth')\n", "m_speedup.speedup_model()" ] }, { "cell_type": "code", "execution_count": 17, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "NaiveModel(\n", " (conv1): Conv2d(1, 10, kernel_size=(5, 5), stride=(1, 1))\n", " (conv2): Conv2d(10, 50, kernel_size=(5, 5), stride=(1, 1))\n", " (fc1): Linear(in_features=800, out_features=500, bias=True)\n", " (fc2): Linear(in_features=500, out_features=10, bias=True)\n", " (relu1): ReLU6()\n", " (relu2): ReLU6()\n", " (relu3): ReLU6()\n", " (max_pool1): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)\n", " (max_pool2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)\n", ")\n" ] } ], "source": [ "# `conv1` has been replaced: `Conv2d(1, 20, kernel_size=(5, 5), stride=(1, 1))` is now `Conv2d(1, 10, kernel_size=(5, 5), stride=(1, 1))`,\n", "# and the following layer `conv2` has also changed, because the input channels of `conv2` must match the remaining output channels of `conv1`.\n", "\n", "print(model)" ] }, { "cell_type": "code", "execution_count": 18, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Train Epoch: 0 [0/60000 (0%)]\tLoss: 0.306930\n", "Train Epoch: 0 [6400/60000 (11%)]\tLoss: 0.045807\n", "Train Epoch: 0 [12800/60000 (21%)]\tLoss: 0.049293\n", "Train Epoch: 0 [19200/60000 (32%)]\tLoss: 0.031464\n", "Train Epoch: 0 [25600/60000 (43%)]\tLoss: 0.005392\n", "Train Epoch: 0 [32000/60000 (53%)]\tLoss: 0.005652\n", "Train Epoch: 0 [38400/60000 (64%)]\tLoss: 0.040619\n", "Train Epoch: 0 [44800/60000 (75%)]\tLoss: 0.016515\n", "Train Epoch: 0 [51200/60000 (85%)]\tLoss: 
0.092886\n", "Train Epoch: 0 [57600/60000 (96%)]\tLoss: 0.041380\n", "\n", "Test set: Average loss: 0.0257, Accuracy: 9917/10000 (99%)\n", "\n" ] } ], "source": [ "# finetune the model to recover the accuracy.\n", "\n", "optimizer = torch.optim.SGD(model.parameters(), lr=0.01)\n", "\n", "for epoch in range(0, 1):\n", " trainer(model, optimizer, criterion, epoch)\n", " evaluator(model)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# 5. Prepare config_list for quantization" ] }, { "cell_type": "code", "execution_count": 19, "metadata": {}, "outputs": [], "source": [ "config_list = [{\n", " 'quant_types': ['weight', 'input'],\n", " 'quant_bits': {'weight': 8, 'input': 8},\n", " 'op_names': ['conv1', 'conv2']\n", "}]" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# 6. Choose a quantizer and quantizing" ] }, { "cell_type": "code", "execution_count": 20, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "NaiveModel(\n", " (conv1): QuantizerModuleWrapper(\n", " (module): Conv2d(1, 10, kernel_size=(5, 5), stride=(1, 1))\n", " )\n", " (conv2): QuantizerModuleWrapper(\n", " (module): Conv2d(10, 50, kernel_size=(5, 5), stride=(1, 1))\n", " )\n", " (fc1): Linear(in_features=800, out_features=500, bias=True)\n", " (fc2): Linear(in_features=500, out_features=10, bias=True)\n", " (relu1): ReLU6()\n", " (relu2): ReLU6()\n", " (relu3): ReLU6()\n", " (max_pool1): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)\n", " (max_pool2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)\n", ")" ] }, "execution_count": 20, "metadata": {}, "output_type": "execute_result" } ], "source": [ "from nni.algorithms.compression.pytorch.quantization import QAT_Quantizer\n", "\n", "quantizer = QAT_Quantizer(model, config_list, optimizer)\n", "quantizer.compress()" ] }, { "cell_type": "code", "execution_count": 21, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Train Epoch: 0 [0/60000 (0%)]\tLoss: 0.004960\n", "Train Epoch: 0 [6400/60000 (11%)]\tLoss: 0.036269\n", "Train Epoch: 0 [12800/60000 (21%)]\tLoss: 0.018744\n", "Train Epoch: 0 [19200/60000 (32%)]\tLoss: 0.021916\n", "Train Epoch: 0 [25600/60000 (43%)]\tLoss: 0.003095\n", "Train Epoch: 0 [32000/60000 (53%)]\tLoss: 0.003947\n", "Train Epoch: 0 [38400/60000 (64%)]\tLoss: 0.032094\n", "Train Epoch: 0 [44800/60000 (75%)]\tLoss: 0.017358\n", "Train Epoch: 0 [51200/60000 (85%)]\tLoss: 0.083886\n", "Train Epoch: 0 [57600/60000 (96%)]\tLoss: 0.040433\n", "\n", "Test set: Average loss: 0.0247, Accuracy: 9917/10000 (99%)\n", "\n" ] } ], "source": [ "# finetune the model for calibration.\n", "\n", "for epoch in range(0, 1):\n", " trainer(model, optimizer, criterion, epoch)\n", " evaluator(model)" ] }, { "cell_type": "code", "execution_count": 22, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "[2021-07-26 22:34:41] INFO (nni.compression.pytorch.compressor/MainThread) Model state_dict saved to quantized_naive_mnist_l1filter.pth\n", "[2021-07-26 22:34:41] INFO (nni.compression.pytorch.compressor/MainThread) Mask dict saved to calibration_naive_mnist_l1filter.pth\n" ] }, { "data": { "text/plain": [ "{'conv1': {'weight_bit': 8,\n", " 'tracked_min_input': -0.42417848110198975,\n", " 'tracked_max_input': 2.8212687969207764},\n", " 'conv2': {'weight_bit': 8,\n", " 'tracked_min_input': 0.0,\n", " 'tracked_max_input': 4.246923446655273}}" ] }, "execution_count": 22, "metadata": {}, "output_type": "execute_result" } ], "source": [ 
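"# capture the calibration config returned by `export_model`; it is needed for the TensorRT speedup in section 7\n", "# (the returned dict is displayed as this cell's output).\n",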
"# export the sparsified model state to './quantized_naive_mnist_l1filter.pth'.\n", "# export the calibration config to './calibration_naive_mnist_l1filter.pth'.\n", "\n", "quantizer.export_model(model_path='quantized_naive_mnist_l1filter.pth', calibration_path='calibration_naive_mnist_l1filter.pth')" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# 7. Speed Up" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# speed up with tensorRT\n", "\n", "engine = ModelSpeedupTensorRT(model, (32, 1, 28, 28), config=calibration_config, batchsize=32)\n", "engine.compress()" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.8.8" } }, "nbformat": 4, "nbformat_minor": 5 }