benchmark_view_transformer.py 5.63 KB
Newer Older
lishj6's avatar
lishj6 committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
# Copyright (c) OpenMMLab. All rights reserved.
import argparse
import time

import numpy as np
import torch
from mmcv import Config
from mmcv.parallel import MMDataParallel
from mmcv.runner import load_checkpoint

from mmdet3d.datasets import build_dataloader, build_dataset
from mmdet3d.models import build_detector


def parse_args():
    """Parse the command-line arguments of the benchmark script.

    Returns:
        argparse.Namespace: Parsed arguments with the attributes ``config``,
            ``checkpoint``, ``samples`` (int), ``log_interval`` (int),
            ``mem_only`` (bool) and ``no_acceleration`` (bool).
    """
    parser = argparse.ArgumentParser(description='MMDet benchmark a model')
    parser.add_argument('config', help='test config file path')
    parser.add_argument('checkpoint', help='checkpoint file')
    # ``type=int`` matters: without it a value given on the command line stays
    # a string, while the benchmark loop compares it against, and takes a
    # modulo with, integer iteration counters.
    parser.add_argument(
        '--samples', type=int, default=1000, help='samples to benchmark')
    parser.add_argument(
        '--log-interval', type=int, default=50, help='interval of logging')
    parser.add_argument(
        '--mem-only',
        action='store_true',
        help='Conduct the memory analysis only')
    parser.add_argument(
        '--no-acceleration',
        action='store_true',
        help='Omit the pre-computation acceleration')
    args = parser.parse_args()
    return args


def main():
    """Benchmark the view transformer of a BEVDet model.

    Builds the test dataloader and the detector described by the config given
    on the command line and loads the checkpoint. For every sample it runs
    the image encoder and the depth net, then times ``view_transform`` on
    their output. On the first sample it additionally prints a memory
    analysis of ``pre_compute`` (only when acceleration is enabled) and of
    ``view_transform_core``.

    Returns:
        float | None: Number of timed ``view_transform`` calls per second,
            averaged over the iterations after the warm-up. ``None`` when
            ``--mem-only`` is given or when no iteration was timed.

    Raises:
        AssertionError: If the configured model type is not ``'BEVDet'``.
    """
    args = parse_args()
    # Coerce defensively so the loop arithmetic below also works when the
    # parser hands these options over as strings.
    num_samples = int(args.samples)
    log_interval = int(args.log_interval)

    cfg = Config.fromfile(args.config)
    # set cudnn_benchmark
    if cfg.get('cudnn_benchmark', False):
        torch.backends.cudnn.benchmark = True
    cfg.model.pretrained = None
    cfg.data.test.test_mode = True

    # build the dataloader
    # TODO: support multiple images per gpu (only minor changes are needed)
    dataset = build_dataset(cfg.data.test)
    data_loader = build_dataloader(
        dataset,
        samples_per_gpu=1,
        workers_per_gpu=cfg.data.workers_per_gpu,
        dist=False,
        shuffle=False)

    # build the model and load checkpoint
    if not args.no_acceleration:
        cfg.model.img_view_transformer.accelerate = True
    cfg.model.train_cfg = None
    assert cfg.model.type == 'BEVDet', \
        'Please use class BEVDet for ' \
        'view transformation inference ' \
        'speed estimation instead of %s' % cfg.model.type
    model = build_detector(cfg.model, test_cfg=cfg.get('test_cfg'))
    load_checkpoint(model, args.checkpoint, map_location='cpu')
    model = MMDataParallel(model, device_ids=[0])

    model.eval()

    # the first several iterations may be very slow so skip them
    num_warmup = 100
    pure_inf_time = 0.0
    # Number of iterations whose time has been added to ``pure_inf_time``.
    timed_iters = 0
    view_transformer = model.module.img_view_transformer
    D = view_transformer.D
    out_channels = view_transformer.out_channels
    depth_net = view_transformer.depth_net
    # benchmark with several samples and take the average
    for i, data in enumerate(data_loader):

        with torch.no_grad():
            img_feat, _ = \
                model.module.image_encoder(data['img_inputs'][0][0].cuda())
            B, N, C, H, W = img_feat.shape
            x = depth_net(img_feat.reshape(B * N, C, H, W))
            # The first D channels are the depth logits, the following
            # out_channels channels are the features to be transformed.
            depth_digit = x[:, :D, ...]
            tran_feat = x[:, D:D + out_channels, ...]
            depth = depth_digit.softmax(dim=1)
        # Encoded image features followed by the remaining entries of
        # ``img_inputs`` moved to the GPU.
        vt_inputs = [img_feat] + \
            [d.cuda() for d in data['img_inputs'][0][1:]]

        if i == 0:
            precomputed_memory_allocated = 0.0
            if view_transformer.accelerate:
                start_mem_allocated = torch.cuda.memory_allocated()
                view_transformer.pre_compute(vt_inputs)
                end_mem_allocated = torch.cuda.memory_allocated()
                precomputed_memory_allocated = \
                    end_mem_allocated - start_mem_allocated
                ref_max_mem_allocated = torch.cuda.max_memory_allocated()
                # occupy the memory: fill the gap between the current
                # allocation and the peak reached so far (4 bytes per
                # float32 element). Presumably this makes the peak delta
                # measured below reflect view_transform_core only. The
                # tensor is never read; the name just keeps it alive.
                size = (ref_max_mem_allocated - end_mem_allocated) // 4
                occupy_tensor = torch.zeros(  # noqa: F841
                    size=(size, ), device='cuda', dtype=torch.float32)
            # NOTE(review): without acceleration no filler tensor is
            # allocated, so the peak read below may still include earlier
            # allocations (e.g. the image encoder) -- confirm whether that
            # is intended.
            print('Memory analysis: \n'
                  'precomputed_memory_allocated : %d B / %.01f MB \n' %
                  (precomputed_memory_allocated,
                   precomputed_memory_allocated / 1024 / 1024))
            start_mem_allocated = torch.cuda.memory_allocated()
            view_transformer.view_transform_core(
                vt_inputs, depth, tran_feat)[0]
            end_max_mem_allocated = torch.cuda.max_memory_allocated()
            peak_memory_allocated = \
                end_max_mem_allocated - start_mem_allocated
            total_memory_requirement = \
                precomputed_memory_allocated + peak_memory_allocated
            print('Memory analysis: \n'
                  'Memory requirement : %d B / %.01f MB \n' %
                  (total_memory_requirement,
                   total_memory_requirement / 1024 / 1024))
            if args.mem_only:
                return None

        # Synchronize before and after so that the wall-clock interval
        # covers the complete GPU work of this call.
        torch.cuda.synchronize()
        start_time = time.perf_counter()
        with torch.no_grad():
            view_transformer.view_transform(vt_inputs, depth, tran_feat)[0]
        torch.cuda.synchronize()
        elapsed = time.perf_counter() - start_time

        if i >= num_warmup:
            # Each timed iteration is accumulated exactly once.
            pure_inf_time += elapsed
            timed_iters += 1
            if log_interval > 0 and (i + 1) % log_interval == 0:
                fps = timed_iters / pure_inf_time
                print(f'Done image [{i + 1:<3}/ {num_samples}], '
                      f'fps: {fps:.1f} img / s')

        if (i + 1) == num_samples:
            break

    # Reached either after ``num_samples`` iterations or when the dataloader
    # is exhausted first; report whatever has been timed.
    if timed_iters == 0 or pure_inf_time <= 0:
        print(f'No iteration was timed: the first {num_warmup} iterations '
              'are warm-up and excluded from the measurement.')
        return None
    fps = timed_iters / pure_inf_time
    print(f'Overall fps: {fps:.1f} img / s')
    return fps


if __name__ == '__main__':
    # Run the benchmark ``repeat_times`` times and report the mean and the
    # standard deviation of the overall fps across the runs.
    repeat_times = 1
    fps_list = []
    for run_idx in range(repeat_times):
        fps = main()
        # main() returns None when it produces no speed measurement
        # (e.g. with --mem-only); such runs must not enter the statistics.
        if fps is not None:
            fps_list.append(fps)
        # Pause between consecutive runs only; after the last run there is
        # nothing left to wait for.
        if run_idx + 1 < repeat_times:
            time.sleep(5)
    if fps_list:
        fps_arr = np.array(fps_list, dtype=np.float32)
        print(f'Mean Overall fps: {fps_arr.mean():.4f} +'
              f' {np.sqrt(fps_arr.var()):.4f} img / s')