mnist_qat.py 18.9 KB
Newer Older
yan.yan's avatar
yan.yan committed
1
# Copyright 2021 Yan Yan
2
#
yan.yan's avatar
yan.yan committed
3
4
5
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
6
#
yan.yan's avatar
yan.yan committed
7
#     http://www.apache.org/licenses/LICENSE-2.0
8
#
yan.yan's avatar
yan.yan committed
9
10
11
12
13
14
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

yanyan's avatar
yanyan committed
15
from __future__ import print_function
yan.yan's avatar
yan.yan committed
16

yanyan's avatar
yanyan committed
17
import argparse
yan.yan's avatar
yan.yan committed
18
19
20
21
import contextlib
import copy
from typing import Dict, Optional

yanyan's avatar
yanyan committed
22
import torch
yan.yan's avatar
yan.yan committed
23
24
25
26
import torch.ao.quantization
import torch.ao.quantization.quantize_fx as qfx
import torch.cuda.amp
import torch.fx
yanyan's avatar
yanyan committed
27
28
29
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
yan.yan's avatar
yan.yan committed
30
31
32
33
from torch.ao.quantization import (DeQuantStub, QuantStub,
                                   get_default_qconfig_mapping)
from torch.ao.quantization.fx._lower_to_native_backend import \
    STATIC_LOWER_FUSED_MODULE_MAP, STATIC_LOWER_MODULE_MAP
yanyan's avatar
yanyan committed
34
from torch.optim.lr_scheduler import StepLR
yan.yan's avatar
yan.yan committed
35
36
37
from torchvision import datasets, transforms

import spconv.pytorch as spconv
yan.yan's avatar
yan.yan committed
38
39
import spconv.pytorch.quantization as spconvq
from spconv.pytorch.quantization import get_default_spconv_trt_ptq_qconfig
yan.yan's avatar
yan.yan committed
40
41
42
43
44
45
46
from spconv.pytorch.quantization.backend_cfg import \
    SPCONV_STATIC_LOWER_FUSED_MODULE_MAP, SPCONV_STATIC_LOWER_MODULE_MAP
from spconv.pytorch.quantization.core import quantize_per_tensor
from spconv.pytorch.quantization.fake_q import \
    get_default_spconv_qconfig_mapping
from spconv.pytorch.quantization.intrinsic.modules import SpconvBnAddReLUNd, SpconvAddReLUNd
import spconv.pytorch.quantization.intrinsic.quantized as snniq
yan.yan's avatar
v2.1  
yan.yan committed
47
48
49

@contextlib.contextmanager
def identity_ctx():
    """No-op context manager: enters and exits without doing anything.

    The ``with`` target, if any, is bound to ``None``.
    """
    yield None
yanyan's avatar
yanyan committed
51

yan.yan's avatar
yan.yan committed
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
class SubMConvBNReLU(spconv.SparseSequential):
    """``SubMConv2d`` -> ``BatchNorm1d`` -> ``ReLU`` packed in one sequential.

    The convolution has no bias and uses ``(kernel_size - 1) // 2`` padding.
    The three layers are passed positionally, in that order.
    """

    def __init__(self, in_planes, out_planes, kernel_size=3, stride=1, groups=1):
        conv = spconv.SubMConv2d(
            in_planes,
            out_planes,
            kernel_size,
            stride,
            (kernel_size - 1) // 2,
            groups=groups,
            bias=False,
        )
        norm = nn.BatchNorm1d(out_planes, momentum=0.1)
        # Non-inplace activation, as in the original definition.
        act = nn.ReLU(inplace=False)
        super().__init__(conv, norm, act)

class SparseConvBNReLU(spconv.SparseSequential):
    """``SparseConv2d`` -> ``BatchNorm1d`` -> ``ReLU`` packed in one sequential.

    The convolution has no bias and uses ``(kernel_size - 1) // 2`` padding.
    Note that the fifth positional parameter is ``groups``, not ``padding``.
    """

    def __init__(self, in_planes, out_planes, kernel_size=3, stride=1, groups=1):
        conv = spconv.SparseConv2d(
            in_planes,
            out_planes,
            kernel_size,
            stride,
            (kernel_size - 1) // 2,
            groups=groups,
            bias=False,
        )
        norm = nn.BatchNorm1d(out_planes, momentum=0.1)
        # Non-inplace activation, as in the original definition.
        act = nn.ReLU(inplace=False)
        super().__init__(conv, norm, act)
yanyan's avatar
yanyan committed
71

yan.yan's avatar
yan.yan committed
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
class SparseBasicBlock(spconv.SparseModule):
    """Residual block built from two ``SubMConv2d`` + ``BatchNorm1d`` stages.

    The first stage ends in a ReLU; the second stage's output is added to
    the shortcut (the input, or ``downsample(input)`` when a ``downsample``
    module is given) and passed through a final ``SparseReLU``.
    """
    expansion = 1

    def __init__(self,
                 in_planes, out_planes,
                 stride=1,
                 downsample=None):
        super().__init__()
        first_conv = spconv.SubMConv2d(in_planes, out_planes, 3, stride, 1, bias=False)
        second_conv = spconv.SubMConv2d(out_planes, out_planes, 3, stride, 1, bias=False)

        first_bn = nn.BatchNorm1d(out_planes, momentum=0.1)
        second_bn = nn.BatchNorm1d(out_planes, momentum=0.1)

        # Keyword names (conv/bn/relu) become the sub-module names.
        self.conv1_bn_relu = spconv.SparseSequential(
            conv=first_conv, bn=first_bn, relu=nn.ReLU(inplace=True))
        self.conv2_bn = spconv.SparseSequential(conv=second_conv, bn=second_bn)

        self.relu = spconv.SparseReLU(inplace=True)
        self.downsample = downsample
        # Not called in forward(); presumably kept so FX pattern matching
        # has an identity node to anchor on -- TODO confirm.
        self.iden_for_fx_match = spconv.SparseIdentity()

    def forward(self, x: spconv.SparseConvTensor):
        """Return ``relu(conv2_bn(conv1_bn_relu(x)) + shortcut)``."""
        # Main branch is evaluated before the shortcut, as in the original.
        main_branch = self.conv2_bn(self.conv1_bn_relu(x))
        shortcut = x if self.downsample is None else self.downsample(x)
        return self.relu(main_branch + shortcut)

yanyan's avatar
yanyan committed
106
107
108
109
class Net(nn.Module):
    """Sparse-conv feature extractor followed by a two-layer dense classifier.

    ``forward`` takes a ready-made ``SparseConvTensor``. The ``quant`` and
    ``dequant`` stubs are created but are not called in ``forward``.
    """

    def __init__(self):
        super().__init__()
        backbone = [
            SubMConvBNReLU(1, 32, 3),
            SubMConvBNReLU(32, 64, 3),
            SparseConvBNReLU(64, 64, 2, 2),
            spconv.ToDense(),
        ]
        self.net = spconv.SparseSequential(*backbone)
        self.fc1 = nn.Linear(14 * 14 * 64, 128)
        self.fc2 = nn.Linear(128, 10)
        # NOTE(review): Dropout2d is applied to the flattened tensor in
        # forward(); confirm that this is intended rather than nn.Dropout.
        self.dropout1 = nn.Dropout2d(0.25)
        self.dropout2 = nn.Dropout2d(0.5)
        self.quant = QuantStub()
        self.dequant = DeQuantStub()

    def forward(self, x_sp: spconv.SparseConvTensor):
        """Return per-class log-probabilities for the sparse input batch."""
        dense = self.net(x_sp)
        flat = torch.flatten(dense, 1)
        hidden = F.relu(self.fc1(self.dropout1(flat)))
        logits = self.fc2(self.dropout2(hidden))
        return F.log_softmax(logits, dim=1)
yanyan's avatar
yanyan committed
139

yan.yan's avatar
yan.yan committed
140
141
142
143
144
145
146
147
148
149
150
class NetV2(nn.Module):
    """Sparse-conv backbone plus dense classifier with quant/dequant stubs.

    Unlike ``Net``, ``forward`` receives the raw pieces of a sparse tensor
    (features, indices, batch size) and builds the ``SparseConvTensor``
    itself with a fixed ``[28, 28]`` spatial shape. No dropout is used.
    """

    def __init__(self):
        super(NetV2, self).__init__()
        self.net = spconv.SparseSequential(
            SubMConvBNReLU(1, 32, 3),
            SubMConvBNReLU(32, 64, 3),
            SparseConvBNReLU(64, 64, 2, 2),
            spconv.ToDense(),
        )
        self.fc1 = nn.Linear(14 * 14 * 64, 128)
        self.fc2 = nn.Linear(128, 10)
        self.quant = QuantStub()
        self.dequant = DeQuantStub()

    def forward(self, features: torch.Tensor, indices: torch.Tensor, batch_size: int):
        """Return per-class log-probabilities.

        Args:
            features: feature tensor of the active sites.
            indices: index tensor of the active sites.
            batch_size: number of samples in the batch.
        """
        # The stub output must feed the sparse tensor. Previously it was
        # assigned to an unused variable, so the network consumed the
        # un-quantized features and the QuantStub was dead code.
        features = self.quant(features)
        x_sp = spconv.SparseConvTensor(features, indices, [28, 28], batch_size)
        x = self.net(x_sp)
        x = torch.flatten(x, 1)
        x = self.fc1(x)
        x = F.relu(x)
        x = self.fc2(x)
        x = self.dequant(x)
        output = F.log_softmax(x, dim=1)
        return output

class NetPTQ(nn.Module):
    """Classifier made only of sparse convolutions (no linear layers).

    PyTorch currently does not support CUDA int8 inference, so a purely
    sparse network is built here.
    """

    def __init__(self):
        super().__init__()
        stages = [
            SubMConvBNReLU(1, 32, 3),
            SubMConvBNReLU(32, 64, 3),
            SparseConvBNReLU(64, 64, 2, 2),  # 14x14
            SparseConvBNReLU(64, 64, 2, 2),  # 7x7
            # The fifth positional argument of SparseConvBNReLU is `groups`.
            SparseConvBNReLU(64, 64, 3, 2, groups=1),  # 4x4
            spconv.SparseConv2d(64, 10, 4, 4),
            spconv.ToDense(),
        ]
        self.net = spconv.SparseSequential(*stages)

        self.quant = QuantStub()
        self.dequant = DeQuantStub()

    def forward(self, features: torch.Tensor, indices: torch.Tensor, batch_size: int):
        """Return per-class log-probabilities.

        A ``SparseConvTensor`` with spatial shape ``[28, 28]`` is built from
        the quant-stubbed ``features`` together with ``indices`` and
        ``batch_size``.
        """
        sparse_in = spconv.SparseConvTensor(
            self.quant(features), indices, [28, 28], batch_size)
        dense_out = self.net(sparse_in)
        logits = self.dequant(torch.flatten(dense_out, 1))
        return F.log_softmax(logits, dim=1)

yan.yan's avatar
yan.yan committed
210
211
212
213
214
215
216
217
class ResidualNetPTQ(nn.Module):
    """Purely sparse classifier that includes one ``SparseBasicBlock``.

    PyTorch currently does not support CUDA int8 inference, so a purely
    sparse network is built here.
    """

    def __init__(self):
        super().__init__()
        stages = [
            SubMConvBNReLU(1, 32, 3),
            SparseBasicBlock(32, 32),
            SubMConvBNReLU(32, 64, 3),
            SparseConvBNReLU(64, 64, 2, 2),  # 14x14
            SparseConvBNReLU(64, 64, 2, 2),  # 7x7
            # The fifth positional argument of SparseConvBNReLU is `groups`.
            SparseConvBNReLU(64, 64, 3, 2, groups=1),  # 4x4
            spconv.SparseConv2d(64, 10, 4, 4),
            spconv.ToDense(),
        ]
        self.net = spconv.SparseSequential(*stages)

        self.quant = QuantStub()
        self.dequant = DeQuantStub()

    def forward(self, features: torch.Tensor, indices: torch.Tensor, batch_size: int):
        """Return per-class log-probabilities.

        A ``SparseConvTensor`` with spatial shape ``[28, 28]`` is built from
        the quant-stubbed ``features`` together with ``indices`` and
        ``batch_size``.
        """
        sparse_in = spconv.SparseConvTensor(
            self.quant(features), indices, [28, 28], batch_size)
        dense_out = self.net(sparse_in)
        logits = self.dequant(torch.flatten(dense_out, 1))
        return F.log_softmax(logits, dim=1)
yanyan's avatar
yanyan committed
248

yan.yan's avatar
yan.yan committed
249
250
251
252
253
254
255
256
257
class NetDense(nn.Module):
    """Dense CNN MNIST classifier wrapped in quant/dequant stubs.

    A ``spconv.SparseIdentity`` is applied after the second ReLU and before
    max-pooling.
    """

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 32, 3, 1)
        self.conv2 = nn.Conv2d(32, 64, 3, 1)
        self.dropout1 = nn.Dropout(0.25)
        self.dropout2 = nn.Dropout(0.5)
        self.fc1 = nn.Linear(9216, 128)
        self.fc2 = nn.Linear(128, 10)
        self.iden = spconv.SparseIdentity()

        self.quant = QuantStub()
        self.dequant = DeQuantStub()

    def forward(self, x):
        """Return per-class log-probabilities for the dense batch ``x``."""
        feat = self.quant(x)

        # Convolutional stem.
        feat = F.relu(self.conv1(feat))
        feat = F.relu(self.conv2(feat))
        feat = self.iden(feat)
        feat = F.max_pool2d(feat, 2)

        # Classifier head.
        feat = torch.flatten(self.dropout1(feat), 1)
        feat = F.relu(self.fc1(feat))
        feat = self.fc2(self.dropout2(feat))
        feat = self.dequant(feat)

        return F.log_softmax(feat, dim=1)

yanyan's avatar
yanyan committed
284
285
def train(args, model, device, train_loader, optimizer, epoch):
    """Train ``model`` for one pass over ``train_loader``.

    Args:
        args: parsed options; ``fp16``, ``sparse`` and ``log_interval`` are
            read here.
        model: network to optimise. With ``args.sparse`` it is called as
            ``model(features, indices, batch_size)``, otherwise as
            ``model(data)``.
        device: device the batches are moved to.
        train_loader: iterable of ``(data, target)`` batches; ``len()`` and
            ``.dataset`` are used for the progress line.
        optimizer: optimizer stepped once per batch.
        epoch: epoch number, used only in the progress line.
    """
    model.train()
    use_amp = bool(args.fp16)
    # The gradient scaler is only needed for mixed precision; building it
    # unconditionally emits warnings on CPU-only runs.
    scaler = torch.cuda.amp.grad_scaler.GradScaler() if use_amp else None
    amp_ctx = torch.cuda.amp.autocast() if use_amp else contextlib.nullcontext()
    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        # Autocast wraps only the forward pass and the loss; the backward
        # pass and the optimizer step run outside it, as PyTorch recommends.
        with amp_ctx:
            if args.sparse:
                # The reshape yields an NHWC layout with a single channel.
                data_sp = spconv.SparseConvTensor.from_dense(data.reshape(-1, 28, 28, 1))
                output = model(data_sp.features, data_sp.indices, data_sp.batch_size)
            else:
                output = model(data)
            loss = F.nll_loss(output, target)

        if use_amp:
            assert loss.dtype is torch.float32
            scaler.scale(loss).backward()
            # scaler.step() unscales the gradients first and skips
            # optimizer.step() if they contain infs or NaNs.
            scaler.step(optimizer)
            # Update the scale factor for the next iteration.
            scaler.update()
        else:
            loss.backward()
            optimizer.step()

        if batch_idx % args.log_interval == 0:
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, batch_idx * len(data), len(train_loader.dataset),
                100. * batch_idx / len(train_loader), loss.item()))


yan.yan's avatar
v2.1  
yan.yan committed
329
def test(args, model, device, test_loader):
    """Evaluate ``model`` on ``test_loader`` and print loss and accuracy.

    Args:
        args: parsed options; ``fp16`` and ``sparse`` are read here.
        model: network to evaluate; put into eval mode. With
            ``args.sparse`` it is called as
            ``model(features, indices, batch_size)``, otherwise ``model(data)``.
        device: device the batches are moved to.
        test_loader: iterable of ``(data, target)`` batches exposing
            ``.dataset`` (its length normalises the reported numbers).
    """
    model.eval()
    autocast_ctx = torch.cuda.amp.autocast() if args.fp16 else contextlib.nullcontext()
    loss_sum = 0
    num_correct = 0

    with torch.no_grad():
        for data, target in test_loader:
            data = data.to(device)
            target = target.to(device)
            with autocast_ctx:
                if not args.sparse:
                    output = model(data)
                else:
                    sparse_batch = spconv.SparseConvTensor.from_dense(
                        data.reshape(-1, 28, 28, 1))
                    output = model(sparse_batch.features, sparse_batch.indices,
                                   sparse_batch.batch_size)
            # Sum the per-sample losses of this batch.
            loss_sum += F.nll_loss(output, target, reduction='sum').item()
            # The predicted class is the index of the max log-probability.
            predicted = output.argmax(dim=1, keepdim=True)
            num_correct += predicted.eq(target.view_as(predicted)).sum().item()

    total = len(test_loader.dataset)
    mean_loss = loss_sum / total

    print(
        '\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
            mean_loss, num_correct, total, 100. * num_correct / total))
yanyan's avatar
yanyan committed
361
362


yan.yan's avatar
yan.yan committed
363
364
365
366
367
368
369
370
371
372
373
374
375
def calibrate(args, model: torch.nn.Module, data_loader, device):
    """Run ``model`` forward over ``data_loader`` with gradients disabled.

    The model is put into eval mode and its outputs are discarded. With
    ``args.sparse`` each batch is converted to a sparse tensor and the model
    is called as ``model(features, indices, batch_size)``, otherwise it is
    called as ``model(image)``.
    """
    model.eval()

    with torch.no_grad():
        for image, _ in data_loader:
            image = image.to(device)
            if not args.sparse:
                model(image)
                continue
            sparse_batch = spconv.SparseConvTensor.from_dense(image.reshape(-1, 28, 28, 1))
            model(sparse_batch.features, sparse_batch.indices, sparse_batch.batch_size)

yanyan's avatar
yanyan committed
376
377
378
def _build_arg_parser():
    """Create the command-line parser for the MNIST QAT example."""
    parser = argparse.ArgumentParser(description='PyTorch MNIST Example')
    parser.add_argument('--batch-size',
                        type=int,
                        default=64,
                        metavar='N',
                        help='input batch size for training (default: 64)')
    parser.add_argument('--test-batch-size',
                        type=int,
                        default=1000,
                        metavar='N',
                        help='input batch size for testing (default: 1000)')
    parser.add_argument('--epochs',
                        type=int,
                        default=1,
                        metavar='N',
                        help='number of epochs to train (default: 1)')
    parser.add_argument('--lr',
                        type=float,
                        default=1.0,
                        metavar='LR',
                        help='learning rate (default: 1.0)')
    parser.add_argument('--gamma',
                        type=float,
                        default=0.7,
                        metavar='M',
                        help='Learning rate step gamma (default: 0.7)')
    parser.add_argument('--no-cuda',
                        action='store_true',
                        default=False,
                        help='disables CUDA training')
    parser.add_argument('--seed',
                        type=int,
                        default=1,
                        metavar='S',
                        help='random seed (default: 1)')
    parser.add_argument('--sparse',
                        action='store_true',
                        default=True,
                        help='use sparse conv network instead of dense')
    # `--sparse` defaults to True and is store_true, so on its own it can
    # never be switched off; this flag makes the dense path selectable.
    parser.add_argument('--no-sparse',
                        dest='sparse',
                        action='store_false',
                        help='use the dense network instead of sparse conv')
    parser.add_argument(
        '--log-interval',
        type=int,
        default=10,
        metavar='N',
        help='how many batches to wait before logging training status')
    parser.add_argument('--save-model',
                        action='store_true',
                        default=False,
                        help='For Saving the current Model')
    parser.add_argument('--fp16',
                        action='store_true',
                        default=False,
                        help='For mixed precision training')
    return parser


def main():
    """Train a float MNIST model, fine-tune it with QAT, convert and test it.

    Steps: parse the command line, train and evaluate the float model for
    ``--epochs`` epochs, optionally save its weights, then prepare a deep
    copy for quantization-aware training, fine-tune it for one epoch,
    convert it and evaluate the converted model.
    """
    args = _build_arg_parser().parse_args()
    use_cuda = not args.no_cuda and torch.cuda.is_available()

    torch.manual_seed(args.seed)

    # CUDA is only used for the sparse network; the dense one stays on CPU.
    device = torch.device("cuda" if use_cuda and args.sparse else "cpu")
    qdevice = torch.device("cuda" if use_cuda and args.sparse else "cpu")
    kwargs = {'num_workers': 1, 'pin_memory': True} if use_cuda else {}
    if args.sparse:
        model = ResidualNetPTQ().to(device)
    else:
        model = NetDense().to(device)

    optimizer = optim.Adadelta(model.parameters(), lr=args.lr)
    train_loader = torch.utils.data.DataLoader(
        datasets.MNIST(
            '../data',
            train=True,
            download=True,
            transform=transforms.Compose([
                transforms.ToTensor(),
                # Normalisation is omitted to keep a sparse tensor with
                # lots of zeros.
                # transforms.Normalize((0.1307,), (0.3081,))
            ])),
        batch_size=args.batch_size,
        shuffle=True,
        **kwargs)
    test_loader = torch.utils.data.DataLoader(
        datasets.MNIST(
            '../data',
            train=False,
            transform=transforms.Compose([
                transforms.ToTensor(),
                # Normalisation is omitted to keep a sparse tensor with
                # lots of zeros.
                # transforms.Normalize((0.1307,), (0.3081,))
            ])),
        batch_size=args.test_batch_size,
        shuffle=True,
        **kwargs)

    scheduler = StepLR(optimizer, step_size=1, gamma=args.gamma)
    for epoch in range(1, args.epochs + 1):
        train(args, model, device, train_loader, optimizer, epoch)
        test(args, model, device, test_loader)
        scheduler.step()
    if args.save_model:
        torch.save(model.state_dict(), "mnist_cnn.pt")
    model.eval()
    if not args.sparse:
        model = model.cpu()

    # Quantization-aware training works on a copy of the float model.
    model_qat = copy.deepcopy(model)
    # NOTE(review): the meaning of the `True` arguments below is defined by
    # spconv (presumably "QAT mode") -- confirm against its documentation.
    spconvq.prepare_spconv_torch_inference(True)

    qconfig_mapping_qat = get_default_spconv_qconfig_mapping(True)
    prepare_cfg = spconvq.get_spconv_prepare_custom_config()
    backend_cfg = spconvq.get_spconv_backend_config()

    prepared_model_qat = qfx.prepare_qat_fx(model_qat, qconfig_mapping_qat, (), backend_config=backend_cfg, prepare_custom_config=prepare_cfg)
    # The float-model optimizer holds the parameters of `model`, not those of
    # the prepared copy, so reusing it would leave the QAT weights untouched.
    # Build a new optimizer and continue from the last scheduled rate.
    qat_optimizer = optim.Adadelta(prepared_model_qat.parameters(),
                                   lr=scheduler.get_last_lr()[0])
    train(args, prepared_model_qat, qdevice, train_loader, qat_optimizer, 1)
    converted_model = qfx.convert_fx(prepared_model_qat, qconfig_mapping=qconfig_mapping_qat, backend_config=backend_cfg)
    converted_model = spconvq.transform_qdq(converted_model)
    # Test the converted QAT model with the int8 kernel.
    spconvq.remove_conv_add_dq(converted_model)
    # Some nvrtc compile log is expected here, which means the int8 kernel
    # is used.
    print(converted_model)
    test(args, converted_model, qdevice, test_loader)
yanyan's avatar
yanyan committed
501
502
503

# Run the example only when this file is executed as a script.
if __name__ == '__main__':
    main()