# --------------------------------------------------------
# InternImage
# Copyright (c) 2022 OpenGVLab
# Licensed under The MIT License [see LICENSE for details]
# --------------------------------------------------------

from __future__ import absolute_import, division, print_function

import math
import time

import torch
import torch.nn as nn
from torch.autograd import gradcheck

from functions.dcnv3_func import DCNv3Function, dcnv3_core_pytorch

# Geometry shared by the forward checks: an H_in x W_in input map, batch N,
# M groups with D channels each, and a Kh x Kw sampling kernel.
H_in, W_in = 8, 8
N, M, D = 2, 4, 16
Kh, Kw = 3, 3
# remove_center (a bool, so 0 or 1) subtracts one sampling point from the
# kernel; judging by the name, the dropped point is the kernel centre.
remove_center = False
P = Kh * Kw - remove_center

offset_scale = 2.0
pad = 1
dilation = 1
stride = 1
# Standard convolution output-size formula for the padding/dilation/stride above.
H_out = (H_in + 2 * pad - (dilation * (Kh - 1) + 1)) // stride + 1
W_out = (W_in + 2 * pad - (dilation * (Kw - 1) + 1)) // stride + 1

# Fixed seed so the random test tensors are reproducible between runs.
torch.manual_seed(3)


@torch.no_grad()
def check_forward_equal_with_pytorch_double():
    """Compare the DCNv3 CUDA forward pass with the PyTorch reference in float64.

    Random GPU tensors are built from the module-level geometry
    (N, M, D, H_in, W_in, H_out, W_out, P). The mask is normalised so the P
    weights of every group sum to one. Both implementations run on float64
    copies. The function prints whether the outputs are ``torch.allclose``
    (default tolerances), together with the maximum absolute and relative
    error.

    Returns:
        bool: True when the CUDA output matches the reference.
    """
    inp = torch.rand(N, H_in, W_in, M*D).cuda() * 0.01
    offset = torch.rand(N, H_out, W_out, M*P*2).cuda() * 10
    mask = torch.rand(N, H_out, W_out, M, P).cuda() + 1e-5
    mask /= mask.sum(-1, keepdim=True)
    mask = mask.reshape(N, H_out, W_out, M*P)

    output_pytorch = dcnv3_core_pytorch(
        inp.double(),
        offset.double(),
        mask.double(),
        Kh, Kw, stride, stride, Kh // 2, Kw // 2, dilation, dilation, M, D, offset_scale, remove_center).detach().cpu()

    im2col_step = 2
    output_cuda = DCNv3Function.apply(
        inp.double(),
        offset.double(),
        mask.double(),
        Kh, Kw, stride, stride, Kh // 2, Kw // 2, dilation, dilation, M, D, offset_scale,
        im2col_step, remove_center).detach().cpu()

    fwdok = torch.allclose(output_cuda, output_pytorch)
    abs_diff = (output_cuda - output_pytorch).abs()
    ref_mag = output_pytorch.abs()
    max_abs_err = abs_diff.max()
    # Reference entries can be exactly zero (presumably where every sampling
    # location lands outside the input, since offsets are large relative to
    # the 8x8 map). Dividing by them gives inf/nan, and a single nan turns
    # .max() into nan. Measure relative error only where it is defined.
    nonzero = ref_mag > 0
    if nonzero.any():
        max_rel_err = (abs_diff[nonzero] / ref_mag[nonzero]).max()
    else:
        max_rel_err = abs_diff.new_zeros(())
    print('>>> forward double')
    print(f'* {fwdok} check_forward_equal_with_pytorch_double: max_abs_err {max_abs_err:.2e} max_rel_err {max_rel_err:.2e}')
    return fwdok


@torch.no_grad()
def check_forward_equal_with_pytorch_float():
    """Compare the DCNv3 CUDA forward pass with the PyTorch reference in float32.

    The setup is the same as the float64 check, but the tensors stay float32.
    The tolerance is therefore looser: ``torch.allclose`` uses rtol=1e-2 and
    atol=1e-3. The function prints the pass flag and the maximum
    absolute/relative error.

    Returns:
        bool: True when the CUDA output matches the reference.
    """
    inp = torch.rand(N, H_in, W_in, M*D).cuda() * 0.01
    offset = torch.rand(N, H_out, W_out, M*P*2).cuda() * 10
    mask = torch.rand(N, H_out, W_out, M, P).cuda() + 1e-5
    mask /= mask.sum(-1, keepdim=True)
    mask = mask.reshape(N, H_out, W_out, M*P)

    output_pytorch = dcnv3_core_pytorch(
        inp,
        offset,
        mask,
        Kh, Kw, stride, stride, Kh // 2, Kw // 2, dilation, dilation, M, D, offset_scale, remove_center).detach().cpu()

    im2col_step = 2
    output_cuda = DCNv3Function.apply(
        inp,
        offset,
        mask,
        Kh, Kw, stride, stride, Kh // 2, Kw // 2, dilation, dilation, M, D, offset_scale,
        im2col_step, remove_center).detach().cpu()

    fwdok = torch.allclose(output_cuda, output_pytorch, rtol=1e-2, atol=1e-3)
    abs_diff = (output_cuda - output_pytorch).abs()
    ref_mag = output_pytorch.abs()
    max_abs_err = abs_diff.max()
    # Reference entries can be exactly zero. Dividing by them gives inf/nan,
    # and a single nan turns .max() into nan, so measure relative error only
    # over nonzero reference values.
    nonzero = ref_mag > 0
    if nonzero.any():
        max_rel_err = (abs_diff[nonzero] / ref_mag[nonzero]).max()
    else:
        max_rel_err = abs_diff.new_zeros(())
    print('>>> forward float')
    print(f'* {fwdok} check_forward_equal_with_pytorch_float: max_abs_err {max_abs_err:.2e} max_rel_err {max_rel_err:.2e}')
    return fwdok


def check_backward_equal_with_pytorch_double(channels=4, grad_input=True, grad_offset=True, grad_mask=True):
    """Compare DCNv3 CUDA gradients with the PyTorch reference in float64.

    Float32 leaf tensors are created on the GPU and cast with ``.double()``
    before each implementation runs, so the gradients land on the float32
    leaves. The reference path (``dcnv3_core_pytorch``) back-propagates
    ``output.sum()`` into ``input0/offset0/mask0``. The CUDA path
    (``DCNv3Function``) does the same into ``input1/offset1/mask1``, which are
    detached aliases of the same data with their own ``.grad``. The gradients
    are compared with ``torch.allclose`` (rtol=1e-2, atol=1e-3).

    Args:
        channels: channels per group (D). Batch and group count are fixed
            locally to N=2 and M=2; the spatial size comes from module globals.
        grad_input, grad_offset, grad_mask: which inputs require grad. Only
            the enabled gradients are compared.

    Returns:
        bool: True when every compared gradient matches.

    Raises:
        ValueError: if all three ``grad_*`` flags are False (nothing to check).
    """
    if not (grad_input or grad_offset or grad_mask):
        # With nothing requiring grad, .backward() below would fail anyway.
        raise ValueError('at least one of grad_input, grad_offset, grad_mask must be True')

    N = 2
    M = 2
    H_out = (H_in + 2 * pad - (dilation * (Kh - 1) + 1)) // stride + 1
    W_out = (W_in + 2 * pad - (dilation * (Kw - 1) + 1)) // stride + 1

    D = channels
    input0 = torch.rand(N, H_in, W_in, M*D).cuda() * 0.01
    offset0 = torch.rand(N, H_out, W_out, M*P*2).cuda() * 10
    mask0 = torch.rand(N, H_out, W_out, M, P).cuda() + 1e-5
    mask0 /= mask0.sum(-1, keepdim=True)
    mask0 = mask0.reshape(N, H_out, W_out, M*P)
    input0.requires_grad = grad_input
    offset0.requires_grad = grad_offset
    mask0.requires_grad = grad_mask

    output_pytorch = dcnv3_core_pytorch(
        input0.double(),
        offset0.double(),
        mask0.double(),
        Kh, Kw, stride, stride, Kh // 2, Kw // 2, dilation, dilation, M, D, offset_scale, remove_center)
    output_pytorch.sum().backward()

    input1 = input0.detach()
    offset1 = offset0.detach()
    mask1 = mask0.detach()
    input1.requires_grad = grad_input
    offset1.requires_grad = grad_offset
    mask1.requires_grad = grad_mask

    im2col_step = 2
    output_cuda = DCNv3Function.apply(
        input1.double(),
        offset1.double(),
        mask1.double(),
        Kh, Kw, stride, stride, Kh // 2, Kw // 2, dilation, dilation, M, D, offset_scale,
        im2col_step, remove_center)
    output_cuda.sum().backward()

    def _report(label, ref_grad, cuda_grad):
        # Compare one gradient pair, print a summary line, and return the pass flag.
        ok = torch.allclose(ref_grad, cuda_grad, rtol=1e-2, atol=1e-3)
        abs_diff = (ref_grad - cuda_grad).abs()
        ref_mag = ref_grad.abs()
        max_abs_err = abs_diff.max()
        # Some reference gradients are exactly zero. Dividing by them yields
        # nan/inf and poisons .max(), so use only nonzero entries.
        nonzero = ref_mag > 0
        if nonzero.any():
            max_rel_err = (abs_diff[nonzero] / ref_mag[nonzero]).max()
        else:
            max_rel_err = abs_diff.new_zeros(())
        print(
            f'* {ok} {label}_grad check_backward_equal_with_pytorch_double: max_abs_err {max_abs_err:.2e} max_rel_err {max_rel_err:.2e}')
        return ok

    print(f'>>> backward double: channels {D}')
    all_ok = True
    checks = (
        ('input', grad_input, input0, input1),
        ('offset', grad_offset, offset0, offset1),
        ('mask', grad_mask, mask0, mask1),
    )
    for label, enabled, ref_t, cuda_t in checks:
        # A disabled flag leaves .grad as None, so there is nothing to compare.
        if enabled:
            all_ok = _report(label, ref_t.grad, cuda_t.grad) and all_ok
    return all_ok


def check_backward_equal_with_pytorch_float(channels=4, grad_input=True, grad_offset=True, grad_mask=True):
    """Compare DCNv3 CUDA gradients with the PyTorch reference in float32.

    The reference path (``dcnv3_core_pytorch``) back-propagates
    ``output.sum()`` into ``input0/offset0/mask0``. The CUDA path
    (``DCNv3Function``) does the same into ``input1/offset1/mask1``, which are
    detached aliases of the same data with their own ``.grad``. The gradients
    are compared with ``torch.allclose`` (rtol=1e-2, atol=1e-3).

    Args:
        channels: channels per group (D). Batch and group count are fixed
            locally to N=2 and M=2; the spatial size comes from module globals.
        grad_input, grad_offset, grad_mask: which inputs require grad. Only
            the enabled gradients are compared.

    Returns:
        bool: True when every compared gradient matches.

    Raises:
        ValueError: if all three ``grad_*`` flags are False (nothing to check).
    """
    if not (grad_input or grad_offset or grad_mask):
        # With nothing requiring grad, .backward() below would fail anyway.
        raise ValueError('at least one of grad_input, grad_offset, grad_mask must be True')

    N = 2
    M = 2
    H_out = (H_in + 2 * pad - (dilation * (Kh - 1) + 1)) // stride + 1
    W_out = (W_in + 2 * pad - (dilation * (Kw - 1) + 1)) // stride + 1

    D = channels
    input0 = torch.rand(N, H_in, W_in, M*D).cuda() * 0.01
    offset0 = torch.rand(N, H_out, W_out, M*P*2).cuda() * 10
    mask0 = torch.rand(N, H_out, W_out, M, P).cuda() + 1e-5
    mask0 /= mask0.sum(-1, keepdim=True)
    mask0 = mask0.reshape(N, H_out, W_out, M*P)
    input0.requires_grad = grad_input
    offset0.requires_grad = grad_offset
    mask0.requires_grad = grad_mask

    output_pytorch = dcnv3_core_pytorch(
        input0,
        offset0,
        mask0,
        Kh, Kw, stride, stride, Kh // 2, Kw // 2, dilation, dilation, M, D, offset_scale, remove_center)
    output_pytorch.sum().backward()

    input1 = input0.detach()
    offset1 = offset0.detach()
    mask1 = mask0.detach()
    input1.requires_grad = grad_input
    offset1.requires_grad = grad_offset
    mask1.requires_grad = grad_mask

    im2col_step = 2
    output_cuda = DCNv3Function.apply(
        input1,
        offset1,
        mask1,
        Kh, Kw, stride, stride, Kh // 2, Kw // 2, dilation, dilation, M, D, offset_scale,
        im2col_step, remove_center)
    output_cuda.sum().backward()

    def _report(label, ref_grad, cuda_grad):
        # Compare one gradient pair, print a summary line, and return the pass flag.
        ok = torch.allclose(ref_grad, cuda_grad, rtol=1e-2, atol=1e-3)
        abs_diff = (ref_grad - cuda_grad).abs()
        ref_mag = ref_grad.abs()
        max_abs_err = abs_diff.max()
        # Some reference gradients are exactly zero. Dividing by them yields
        # nan/inf and poisons .max(), so use only nonzero entries.
        nonzero = ref_mag > 0
        if nonzero.any():
            max_rel_err = (abs_diff[nonzero] / ref_mag[nonzero]).max()
        else:
            max_rel_err = abs_diff.new_zeros(())
        print(
            f'* {ok} {label}_grad check_backward_equal_with_pytorch_float: max_abs_err {max_abs_err:.2e} max_rel_err {max_rel_err:.2e}')
        return ok

    print(f'>>> backward float: channels {D}')
    all_ok = True
    checks = (
        ('input', grad_input, input0, input1),
        ('offset', grad_offset, offset0, offset1),
        ('mask', grad_mask, mask0, mask1),
    )
    for label, enabled, ref_t, cuda_t in checks:
        # A disabled flag leaves .grad as None, so there is nothing to compare.
        if enabled:
            all_ok = _report(label, ref_t.grad, cuda_t.grad) and all_ok
    return all_ok


@torch.no_grad()
def check_time_cost(im2col_step=128, repeat=100):
    """Benchmark the DCNv3 CUDA forward pass and print the mean seconds per call.

    Uses a larger problem than the correctness checks: N=512 and a 64x64
    input. M, D and the kernel geometry come from module globals. The offset
    scale passed to the op is fixed at 1.0 here, not the module-level
    ``offset_scale``.

    Args:
        im2col_step: forwarded unchanged to ``DCNv3Function.apply``.
        repeat: number of untimed warm-up calls, and also the number of
            timed calls.

    Raises:
        ValueError: if ``repeat`` is less than 1.
    """
    if repeat < 1:
        raise ValueError(f'repeat must be >= 1, got {repeat}')

    N = 512
    H_in, W_in = 64, 64
    H_out = (H_in + 2 * pad - (dilation * (Kh - 1) + 1)) // stride + 1
    W_out = (W_in + 2 * pad - (dilation * (Kw - 1) + 1)) // stride + 1

    inp = torch.rand(N, H_in, W_in, M*D).cuda() * 0.01
    offset = torch.rand(N, H_out, W_out, M*P*2).cuda() * 10
    mask = torch.rand(N, H_out, W_out, M, P).cuda() + 1e-5
    mask /= mask.sum(-1, keepdim=True)
    mask = mask.reshape(N, H_out, W_out, M*P)
    print(
        f'>>> time cost: im2col_step {im2col_step}; input {inp.shape}; points {P} ')

    def _run():
        # One forward call with the benchmark configuration.
        return DCNv3Function.apply(
            inp,
            offset,
            mask,
            Kh, Kw, stride, stride, Kh // 2, Kw // 2, dilation, dilation, M, D, 1.0,
            im2col_step, remove_center)

    # Untimed warm-up so one-off start-up costs are excluded from the measurement.
    for _ in range(repeat):
        _run()
    # CUDA launches are asynchronous: synchronize before starting and before
    # stopping the clock so that the kernels' execution time is what gets measured.
    torch.cuda.synchronize()
    start = time.perf_counter()
    for _ in range(repeat):
        _run()
    torch.cuda.synchronize()
    print(f'forward time cost: {(time.perf_counter() - start) / repeat}')


if __name__ == '__main__':
    # Forward correctness first (float64, then float32).
    check_forward_equal_with_pytorch_double()
    check_forward_equal_with_pytorch_float()

    # Gradient checks over a sweep of per-group channel counts, including
    # non-power-of-two and large values.
    channel_sweep = (1, 16, 30, 32, 64, 71, 1025)
    for ch in channel_sweep:
        check_backward_equal_with_pytorch_double(ch, True, True, True)
    for ch in channel_sweep:
        check_backward_equal_with_pytorch_float(ch, True, True, True)

    # Timing for im2col_step = 128, 256, 512.
    for step in (128 * 2 ** k for k in range(3)):
        check_time_cost(step)