unit_test.py 9.37 KB
Newer Older
unknown's avatar
unknown committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
# --------------------------------------------------------
# Fused kernel for window process for SwinTransformer
# Copyright (c) 2022 Nvidia
# Licensed under The MIT License [see LICENSE for details]
# --------------------------------------------------------

import torch
import swin_window_process
import random
import time
import unittest


class WindowProcess(torch.autograd.Function):
    @staticmethod
    def forward(ctx, input, B, H, W, C, shift_size, window_size):
        output = swin_window_process.roll_and_window_partition_forward(input, B, H, W, C, shift_size, window_size)

        ctx.B = B
        ctx.H = H
        ctx.W = W 
        ctx.C = C 
        ctx.shift_size = shift_size
        ctx.window_size = window_size
        return output

    @staticmethod
    def backward(ctx, grad_in):
        B = ctx.B
        H = ctx.H
        W = ctx.W 
        C = ctx.C 
        shift_size = ctx.shift_size
        window_size = ctx.window_size

        grad_out = swin_window_process.roll_and_window_partition_backward(grad_in, B, H, W, C, shift_size, window_size)
        return grad_out, None, None, None, None, None, None, None


class WindowProcessReverse(torch.autograd.Function):
    @staticmethod
    def forward(ctx, input, B, H, W, C, shift_size, window_size):
        output = swin_window_process.window_merge_and_roll_forward(input, B, H, W, C, shift_size, window_size)

        ctx.B = B
        ctx.H = H
        ctx.W = W 
        ctx.C = C 
        ctx.shift_size = shift_size
        ctx.window_size = window_size

        return output

    @staticmethod
    def backward(ctx, grad_in):
        B = ctx.B
        H = ctx.H
        W = ctx.W 
        C = ctx.C 
        shift_size = ctx.shift_size
        window_size = ctx.window_size

        grad_out = swin_window_process.window_merge_and_roll_backward(grad_in, B, H, W, C, shift_size, window_size)
        return grad_out, None, None, None, None, None, None, None


def window_partition(x, window_size):
    """
    Args:
        x: (B, H, W, C)
        window_size (int): window size
    Returns:
        windows: (num_windows*B, window_size, window_size, C)
    """
    B, H, W, C = x.shape
    x = x.view(B, H // window_size, window_size, W // window_size, window_size, C)
    windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)
    return windows

def window_reverse(windows, window_size, H, W):
    """
    Args:
        windows: (num_windows*B, window_size, window_size, C)
        window_size (int): Window size
        H (int): Height of image
        W (int): Width of image
    Returns:
        x: (B, H, W, C)
    """
    B = int(windows.shape[0] / (H * W / window_size / window_size))
    x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1)
    x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)
    return x


def pyt_forward(x, shift_size, window_size):
    # x in shape(B, H, W, C)
    # cyclic shift
    if shift_size > 0:
        shifted_x = torch.roll(x, shifts=(-shift_size, -shift_size), dims=(1, 2))
    else:
        shifted_x = x
    # partition windows
    x_windows = window_partition(shifted_x, window_size)
    return x_windows


def reverse_pyt_forward(attn_windows, shift_size, window_size, H, W):
    # x in shape(B*nH*nW, window_size, window_size, C)
    shifted_x = window_reverse(attn_windows, window_size, H, W)
    if shift_size > 0:
        x = torch.roll(shifted_x, shifts=(shift_size, shift_size), dims=(1, 2))
    else:
        x = shifted_x
    return x


def copy_one_tensor(input, requires_grad=True):
    input1 = input.clone().detach().requires_grad_(requires_grad).cuda()
    return input1

class Test_WindowProcess(unittest.TestCase):
    def setUp(self):
        self.B = 192
        self.H = 56
        self.W = 56
        self.C = 96
        self.shift_size = 2
        self.window_size = 7
        self.nH = self.H // self.window_size
        self.nW = self.W // self.window_size
    
    def test_roll_and_window_partition_forward(self, dtype=torch.float32):
        input = torch.randn((self.B, self.H, self.W, self.C), dtype=dtype, requires_grad=True).cuda()
        
        input1 = copy_one_tensor(input, True)
        input2 = copy_one_tensor(input, True)

        with torch.no_grad():
            # ori
            expected = pyt_forward(input1, self.shift_size, self.window_size)
            # fused kernel
            fused_output = WindowProcess.apply(input2, self.B, self.H, self.W, self.C, -self.shift_size, self.window_size)
        
        self.assertTrue(torch.equal(expected, fused_output))
        #self.assertTrue(torch.allclose(expected, fused_output, rtol=1e-05, atol=1e-08))
    
    def test_roll_and_window_partition_backward(self, dtype=torch.float32):
        input = torch.randn((self.B, self.H, self.W, self.C), dtype=dtype, requires_grad=True).cuda()
        d_loss_tensor = torch.randn((self.B*self.nW*self.nH, self.window_size, self.window_size, self.C), dtype=dtype).cuda()
        
        input1 = copy_one_tensor(input, True)
        input2 = copy_one_tensor(input, True)

        # ori
        expected = pyt_forward(input1, self.shift_size, self.window_size)
        expected.backward(d_loss_tensor)
        # fused kernel
        fused_output = WindowProcess.apply(input2, self.B, self.H, self.W, self.C, -self.shift_size, self.window_size)
        fused_output.backward(d_loss_tensor)
        
        self.assertTrue(torch.equal(expected, fused_output))
        #self.assertTrue(torch.allclose(expected, fused_output, rtol=1e-05, atol=1e-08))

    def test_window_merge_and_roll_forward(self, dtype=torch.float32):
        input = torch.randn((self.B*self.nH*self.nW, self.window_size, self.window_size, self.C), dtype=dtype, requires_grad=True).cuda()
        
        input1 = copy_one_tensor(input, True)
        input2 = copy_one_tensor(input, True)

        with torch.no_grad():
            # ori
            expected = reverse_pyt_forward(input1, self.shift_size, self.window_size, self.H, self.W)
            # fused kernel
            fused_output = WindowProcessReverse.apply(input2, self.B, self.H, self.W, self.C, self.shift_size, self.window_size)
        
        self.assertTrue(torch.equal(expected, fused_output))
        #self.assertTrue(torch.allclose(expected, fused_output, rtol=1e-05, atol=1e-08))
    

    def test_window_merge_and_roll_backward(self, dtype=torch.float32):
        input = torch.randn((self.B*self.nH*self.nW, self.window_size, self.window_size, self.C), dtype=dtype, requires_grad=True).cuda()
        d_loss_tensor = torch.randn((self.B, self.H, self.W, self.C), dtype=dtype, requires_grad=True).cuda()
        
        input1 = copy_one_tensor(input, True)
        input2 = copy_one_tensor(input, True)

        # ori
        expected = reverse_pyt_forward(input1, self.shift_size, self.window_size, self.H, self.W)
        expected.backward(d_loss_tensor)
        # fused kernel
        fused_output = WindowProcessReverse.apply(input2, self.B, self.H, self.W, self.C, self.shift_size, self.window_size)
        fused_output.backward(d_loss_tensor)
        
        self.assertTrue(torch.equal(expected, fused_output))
        #self.assertTrue(torch.allclose(expected, fused_output, rtol=1e-05, atol=1e-08))

    def test_forward_backward_speed(self, dtype=torch.float32, times=1000):
        input = torch.randn((self.B*self.nH*self.nW, self.window_size, self.window_size, self.C), dtype=dtype, requires_grad=True).cuda()
        d_loss_tensor = torch.randn((self.B, self.H, self.W, self.C), dtype=dtype, requires_grad=True).cuda()
        
        input1 = copy_one_tensor(input, True)
        input2 = copy_one_tensor(input, True)

        # SwinTransformer official
        def run_pyt(t=1000):
            for _ in range(t):
                expected = reverse_pyt_forward(input1, self.shift_size, self.window_size, self.H, self.W)
                expected.backward(d_loss_tensor)

        # my op
        def run_fusedop(t=1000):
            for _ in range(t):
                fused_output = WindowProcessReverse.apply(input2, self.B, self.H, self.W, self.C, self.shift_size, self.window_size)
                fused_output.backward(d_loss_tensor)
        
        torch.cuda.synchronize()
        t1 = time.time()
        run_pyt(t=times)
        torch.cuda.synchronize()
        t2 = time.time()
        run_fusedop(t=times)
        torch.cuda.synchronize()
        t3 = time.time()
        self.assertTrue((t3 - t2) < (t2 - t1))

        print('Run {} times'.format(times))
        print('Original time cost: {}'.format(t2 - t1))
        print('Fused op time cost: {}'.format(t3 - t2))
    
    def test_roll_and_window_partition_forward_fp16(self, dtype=torch.float16):
        self.test_roll_and_window_partition_forward(dtype=dtype)

    def test_roll_and_window_partition_backward_fp16(self, dtype=torch.float16):
        self.test_roll_and_window_partition_backward(dtype=dtype)

    def test_window_merge_and_roll_forward_fp16(self, dtype=torch.float16):
        self.test_window_merge_and_roll_forward(dtype=dtype)
    
    def test_window_merge_and_roll_backward_fp16(self, dtype=torch.float16):
        self.test_window_merge_and_roll_backward(dtype=dtype)

    def test_forward_backward_speed_fp16(self, dtype=torch.float16, times=1000):
        self.test_forward_backward_speed(dtype=dtype, times=times)


if __name__ == '__main__':
    print('Pass only two tensors are exactly the same (using torch.equal).\n')
    torch.manual_seed(0)
    unittest.main(verbosity=2)