demo_rewrite.py 3.88 KB
Newer Older
limm's avatar
limm committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
# Copyright (c) OpenMMLab. All rights reserved.
import asyncio
import os
import shutil

import torch
from pyppeteer import launch
from torchvision.models import resnet18

from mmdeploy.core import FUNCTION_REWRITER, RewriterContext, patch_model
from mmdeploy.utils import get_root_logger


@FUNCTION_REWRITER.register_rewriter(
    func_name='torchvision.models.ResNet._forward_impl')
def forward_of_resnet(self, x):
    """Truncated replacement for ``ResNet._forward_impl``.

    Only the stem (conv1, bn1, relu, maxpool) and ``layer1`` are applied, so
    the feature map is returned early, after two down-sampling steps.

    Args:
        self: The module whose ``_forward_impl`` is being rewritten.
        x: Input passed to the first convolution.

    Returns:
        The output of ``self.layer1``.
    """
    # apply the stages one after another, in the same order as the original
    for stage_name in ('conv1', 'bn1', 'relu', 'maxpool', 'layer1'):
        x = getattr(self, stage_name)(x)
    return x


def rewrite_resnet18(original_path: str, rewritten_path: str):
    """Export resnet18 to ONNX twice: as is, and with the rewriters enabled.

    Args:
        original_path (str): Where the ONNX file exported from the unmodified
            model is written.
        rewritten_path (str): Where the ONNX file exported from the patched
            model, inside a ``RewriterContext``, is written.
    """
    # the dummy input is created first, then the randomly initialised model
    dummy_input = torch.rand(1, 3, 224, 224)
    model = resnet18(pretrained=False)

    # first export: the model exactly as torchvision defines it
    torch.onnx.export(model, dummy_input, original_path)

    # second export: patched model, with the registered rewriters active
    patched = patch_model(model, cfg={}, backend='default')
    with RewriterContext(cfg={}, backend='default'):
        with torch.no_grad():
            torch.onnx.export(patched, dummy_input, rewritten_path)


def screen_size():
    """Get the screen size through tkinter.

    A temporary Tk root is created only to query the display and is destroyed
    again before returning.

    Returns:
        tuple: ``(width, height)`` of the screen as reported by tkinter.
    """
    import tkinter
    root = tkinter.Tk()
    try:
        # the root is needed only for the query, so keep it off screen
        root.withdraw()
        return root.winfo_screenwidth(), root.winfo_screenheight()
    finally:
        # destroy() tears the root window down; quit() alone left it alive
        root.destroy()


async def visualize(original_path: str, rewritten_path: str):
    """Render two ONNX files on netron.app and save a screenshot of each.

    A non-headless browser is launched, each model is uploaded to its own
    netron.app page, and a screenshot of the horizontally centred half of the
    page is saved to the model path with ``.onnx`` replaced by ``.png``.

    Args:
        original_path (str): Path of the ONNX file of the original model.
        rewritten_path (str): Path of the ONNX file of the rewritten model.
    """
    # launch a web browser
    browser = await launch(headless=False, args=['--start-maximized'])
    try:
        # create two new pages
        page2 = await browser.newPage()
        page1 = await browser.newPage()
        # go to netron.app
        width, height = screen_size()
        await page1.setViewport({'width': width, 'height': height})
        await page2.setViewport({'width': width, 'height': height})
        await page1.goto('https://netron.app/')
        await page2.goto('https://netron.app/')
        await asyncio.sleep(2)

        # open the two local onnx files; use the function arguments rather
        # than module-level globals so the coroutine works for any caller
        mupinput1 = await page1.querySelector("input[type='file']")
        mupinput2 = await page2.querySelector("input[type='file']")
        await mupinput1.uploadFile(original_path)
        await mupinput2.uploadFile(rewritten_path)
        await asyncio.sleep(4)
        # only the page showing the original model is zoomed out
        for _ in range(6):
            await page1.click('#zoom-out-button')
            await asyncio.sleep(0.3)
        await asyncio.sleep(1)

        # capture the middle half of each page; a fresh clip dict is built
        # for every call
        for page, onnx_path in ((page1, original_path),
                                (page2, rewritten_path)):
            await page.screenshot({'path': onnx_path.replace('.onnx', '.png')},
                                  clip={
                                      'x': width / 4,
                                      'y': 0,
                                      'width': width / 2,
                                      'height': height
                                  })
    finally:
        # always shut the browser down, even when a step above fails
        await browser.close()


if __name__ == '__main__':
    import tempfile

    # Work in a fresh, uniquely named directory under the current working
    # directory, so the cleanup below can never delete a pre-existing ./tmp
    # that belongs to the user.
    tmp_dir = tempfile.mkdtemp(prefix='tmp', dir=os.getcwd())
    try:
        original_file_path = os.path.join(tmp_dir, 'original.onnx')
        rewritten_file_path = os.path.join(tmp_dir, 'rewritten.onnx')
        logger = get_root_logger()
        logger.info('Generating resnet18 and its rewritten model...')
        rewrite_resnet18(original_file_path, rewritten_file_path)

        logger.info('Visualizing models through netron...')
        asyncio.get_event_loop().run_until_complete(
            visualize(original_file_path, rewritten_file_path))

        # show the two screenshots written by visualize()
        import mmcv
        image1 = mmcv.imread(original_file_path.replace('.onnx', '.png'))
        image2 = mmcv.imread(rewritten_file_path.replace('.onnx', '.png'))
        mmcv.imshow(image1, win_name='original')
        mmcv.imshow(image2, win_name='rewritten')
    finally:
        # remove the generated files whether or not the demo succeeded
        shutil.rmtree(tmp_dir)