Commit 5b17e272 authored by wangkx1's avatar wangkx1
Browse files

init

parents
# --------------------------------------------------------
# InternImage
# Copyright (c) 2022 OpenGVLab
# Licensed under The MIT License [see LICENSE for details]
# --------------------------------------------------------
from __future__ import absolute_import
from __future__ import print_function
from __future__ import division
import torch
import math
import torch.nn.functional as F
from torch.autograd import Function
from torch.autograd.function import once_differentiable
from torch.cuda.amp import custom_bwd, custom_fwd
from .table import TABLE, BWDTABLE
from DCNv4 import ext
def factors(N):
res = []
for i in range(1, N+1):
if N % i == 0:
res.append(i)
return res
def findspec(B, H, W, G, C):
key = f"{B}x{H}x{W}x{G}x{C}"
if key in TABLE:
return TABLE[key][0], TABLE[key][1]
d_stride = 8
ms = factors(B*H*W)
multiplier = 1
for m in ms:
if m <= 64 and (m * G * C // d_stride) <= 512:
multiplier = m
n_thread = multiplier * G * C // d_stride
key = f"{B}x{H}x{W}x{G}x{C}"
TABLE[key] = (d_stride, n_thread)
return d_stride, n_thread
def find_spec_bwd(B, H, W, G, C):
key = f"{B}x{H}x{W}x{G}x{C}"
if key in BWDTABLE:
return BWDTABLE[key][0], BWDTABLE[key][1]
if C >= 64:
d_stride = 2
else:
d_stride = 1
ms = factors(B*H*W)
multiplier = 1
for m in ms:
if m <= 64 and (m * G * C // d_stride) <= 256:
multiplier = m
n_thread = multiplier * G * C // d_stride
return d_stride, n_thread
class DCNv4Function(Function):
@staticmethod
@custom_fwd
def forward(
ctx, input, offset_mask,
kernel_h, kernel_w, stride_h, stride_w,
pad_h, pad_w, dilation_h, dilation_w,
group, group_channels, offset_scale,
im2col_step, remove_center):
forward_d_stride, forward_block_thread = findspec(input.shape[0], input.shape[1], input.shape[2], group, group_channels)
backward_d_stride, backward_block_thread = find_spec_bwd(input.shape[0], input.shape[1], input.shape[2], group, group_channels)
ctx.kernel_h = kernel_h
ctx.kernel_w = kernel_w
ctx.stride_h = stride_h
ctx.stride_w = stride_w
ctx.pad_h = pad_h
ctx.pad_w = pad_w
ctx.dilation_h = dilation_h
ctx.dilation_w = dilation_w
ctx.group = group
ctx.group_channels = group_channels
ctx.offset_scale = offset_scale
ctx.im2col_step = im2col_step
ctx.remove_center = remove_center
ctx.backward_d_stride = backward_d_stride
ctx.backward_block_thread = backward_block_thread
args = [
input, offset_mask, kernel_h,
kernel_w, stride_h, stride_w, pad_h,
pad_w, dilation_h, dilation_w, group,
group_channels, offset_scale,
ctx.im2col_step,
remove_center,
forward_d_stride,
forward_block_thread,
False,
]
output = ext.dcnv4_forward(*args)
ctx.save_for_backward(input, offset_mask)
return output
@staticmethod
@once_differentiable
@custom_bwd
def backward(ctx, grad_output):
input, offset_mask = ctx.saved_tensors
args = [
input, offset_mask, ctx.kernel_h,
ctx.kernel_w, ctx.stride_h, ctx.stride_w, ctx.pad_h,
ctx.pad_w, ctx.dilation_h, ctx.dilation_w, ctx.group,
ctx.group_channels, ctx.offset_scale, ctx.im2col_step,
grad_output.contiguous(), ctx.remove_center,
ctx.backward_d_stride, ctx.backward_block_thread,
False
]
grad_input, grad_offset_mask = \
ext.dcnv4_backward(*args)
return grad_input, grad_offset_mask, \
None, None, None, None, None, None, None,\
None, None, None, None, None, None
# ------------------------------------------------------------------------------------------------
# Deformable DETR
# Copyright (c) 2020 SenseTime. All Rights Reserved.
# Licensed under the Apache License, Version 2.0 [see LICENSE for details]
# ------------------------------------------------------------------------------------------------
# Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
# ------------------------------------------------------------------------------------------------
from __future__ import absolute_import
from __future__ import print_function
from __future__ import division
import torch
import torch.nn.functional as F
from torch.autograd import Function
from torch.autograd.function import once_differentiable
import numpy as np
from DCNv4 import ext
shm_size_dict = {
"8.0": 163000,
"8.6": 99000,
"8.7": 163000,
"8.9": 99000,
"9.0": 227000,
"7.5": 64000,
"7.0": 96000,
}
cuda_capability = f"{torch.cuda.get_device_properties(0).major}.{torch.cuda.get_device_properties(0).minor}"
cuda_capability = "8.7"
if cuda_capability not in shm_size_dict:
raise NotImplementedError
shm_size_cap = shm_size_dict[cuda_capability]
def factors(N):
res = []
for i in range(1, N+1):
if N % i == 0:
res.append(i)
return res
def findspec(B, Q, G, C):
d_stride = 8
ms = factors(B*Q)
multiplier = 1
for m in ms:
if m <= 64 and (m * G * C // d_stride) <= 512:
multiplier = m
n_thread = multiplier * G * C // d_stride
return d_stride, n_thread
def findspec_bwd(B, Q, G, C):
if C >= 64:
d_stride = 2
else:
d_stride = 1
ms = factors(B*Q)
multiplier = 1
for m in ms:
if m <= 64 and (m * G * C // d_stride) <= 256:
multiplier = m
n_thread = multiplier * G * C // d_stride
return d_stride, n_thread
class FlashDeformAttnFunction(Function):
@staticmethod
@torch.autocast("cuda", enabled=True, dtype=torch.float16)
def forward(
ctx, value, value_spatial_shapes, value_level_start_index,
sampling_loc_attn, im2col_step, K=8
):
ctx.im2col_step = im2col_step
ctx.K = K
d_stride, blockthread = findspec(value.shape[0], sampling_loc_attn.shape[1], value.shape[2], value.shape[3])
d_stride_backward, blockthread_backward = findspec_bwd(value.shape[0], sampling_loc_attn.shape[1], value.shape[2], value.shape[3])
ctx.d_stride_backward = d_stride_backward
ctx.blockthread_backward = blockthread_backward
output = ext.flash_deform_attn_forward(
value,
value_spatial_shapes,
value_level_start_index,
sampling_loc_attn,
ctx.im2col_step,
K,
d_stride,
blockthread,
)
ctx.save_for_backward(value, value_spatial_shapes, value_level_start_index, sampling_loc_attn)
return output
@staticmethod
@once_differentiable
def backward(ctx, grad_output):
value, value_spatial_shapes, value_level_start_index, sampling_loc_attn = ctx.saved_tensors
grad_value, grad_sampling_loc_attn = ext.flash_deform_attn_backward(
value,
value_spatial_shapes,
value_level_start_index,
sampling_loc_attn,
grad_output.contiguous(),
ctx.im2col_step,
ctx.K,
ctx.d_stride_backward,
ctx.blockthread_backward,
)
return grad_value, None, None, grad_sampling_loc_attn, None, None
TABLE = {
"64x56x56x4x16": [
8,
448,
56
],
"64x28x28x4x16": [
8,
448,
56
],
"64x14x14x4x16": [
8,
32,
4
],
"64x7x7x4x16": [
8,
56,
7
],
"1x200x320x4x16": [
8,
32,
4
],
"1x100x160x4x16": [
8,
32,
4
],
"1x50x80x4x16": [
4,
512,
32
],
"1x25x40x4x16": [
4,
320,
20
],
"1x64x64x4x16": [
8,
512,
64
],
"64x56x56x5x16": [
8,
490,
49
],
"64x28x28x5x16": [
8,
490,
49
],
"64x14x14x5x16": [
8,
280,
28
],
"64x7x7x5x16": [
4,
140,
7
],
"1x200x320x5x16": [
4,
400,
20
],
"1x100x160x5x16": [
4,
400,
20
],
"1x50x80x5x16": [
8,
500,
50
],
"1x25x40x5x16": [
8,
20,
2
],
"1x64x64x5x16": [
8,
320,
32
],
"64x56x56x6x16": [
8,
768,
64
],
"64x28x28x6x16": [
8,
672,
56
],
"64x14x14x6x16": [
8,
336,
28
],
"64x7x7x6x16": [
8,
84,
7
],
"1x200x320x6x16": [
4,
600,
25
],
"1x100x160x6x16": [
4,
600,
25
],
"1x50x80x6x16": [
8,
600,
50
],
"1x25x40x6x16": [
2,
240,
5
],
"1x64x64x6x16": [
8,
384,
32
],
"64x56x56x7x16": [
8,
896,
64
],
"64x28x28x7x16": [
8,
686,
49
],
"64x14x14x7x16": [
8,
392,
28
],
"64x7x7x7x16": [
8,
686,
49
],
"1x200x320x7x16": [
8,
700,
50
],
"1x100x160x7x16": [
8,
700,
50
],
"1x50x80x7x16": [
8,
700,
50
],
"1x25x40x7x16": [
8,
70,
5
],
"1x64x64x7x16": [
8,
448,
32
],
"64x56x56x8x16": [
8,
448,
28
],
"64x28x28x8x16": [
8,
448,
28
],
"64x14x14x8x16": [
8,
448,
28
],
"64x7x7x8x16": [
8,
784,
49
],
"1x200x320x8x16": [
8,
800,
50
],
"1x100x160x8x16": [
4,
640,
20
],
"1x50x80x8x16": [
8,
800,
50
],
"1x25x40x8x16": [
4,
64,
2
],
"1x64x64x8x16": [
8,
256,
16
],
"64x56x56x4x32": [
8,
448,
28
],
"64x28x28x4x32": [
8,
448,
28
],
"64x14x14x4x32": [
8,
448,
28
],
"64x7x7x4x32": [
8,
112,
7
],
"1x200x320x4x32": [
8,
512,
32
],
"1x100x160x4x32": [
8,
800,
50
],
"1x50x80x4x32": [
8,
800,
50
],
"1x25x40x4x32": [
4,
128,
4
],
"1x64x64x4x32": [
8,
128,
8
],
"64x56x56x5x32": [
8,
560,
28
],
"64x28x28x5x32": [
8,
560,
28
],
"64x14x14x5x32": [
8,
560,
28
],
"64x7x7x5x32": [
8,
980,
49
],
"1x200x320x5x32": [
8,
500,
25
],
"1x100x160x5x32": [
8,
800,
40
],
"1x50x80x5x32": [
8,
1000,
50
],
"1x25x40x5x32": [
4,
200,
5
],
"1x64x64x5x32": [
8,
640,
32
],
"64x56x56x6x32": [
8,
336,
14
],
"64x28x28x6x32": [
8,
336,
14
],
"64x14x14x6x32": [
8,
336,
14
],
"64x7x7x6x32": [
16,
588,
49
],
"1x200x320x6x32": [
8,
480,
20
],
"1x100x160x6x32": [
8,
480,
20
],
"1x50x80x6x32": [
16,
600,
50
],
"1x25x40x6x32": [
8,
96,
4
],
"1x64x64x6x32": [
8,
768,
32
],
"64x56x56x7x32": [
8,
448,
16
],
"64x28x28x7x32": [
8,
448,
16
],
"64x14x14x7x32": [
8,
196,
7
],
"64x7x7x7x32": [
8,
28,
1
],
"1x200x320x7x32": [
8,
448,
16
],
"1x100x160x7x32": [
8,
448,
16
],
"1x50x80x7x32": [
8,
700,
25
],
"1x25x40x7x32": [
8,
56,
2
],
"1x64x64x7x32": [
8,
896,
32
],
"64x56x56x8x32": [
8,
448,
14
],
"64x28x28x8x32": [
8,
448,
14
],
"64x14x14x8x32": [
8,
448,
14
],
"64x7x7x8x32": [
8,
32,
1
],
"1x200x320x8x32": [
8,
512,
16
],
"1x100x160x8x32": [
8,
800,
25
],
"1x50x80x8x32": [
8,
800,
25
],
"1x25x40x8x32": [
4,
512,
8
],
"1x64x64x8x32": [
8,
32,
1
],
"64x56x56x4x64": [
8,
448,
14
],
"64x28x28x4x64": [
8,
448,
14
],
"64x14x14x4x64": [
8,
448,
14
],
"64x7x7x4x64": [
8,
32,
1
],
"1x200x320x4x64": [
8,
512,
16
],
"1x100x160x4x64": [
8,
512,
16
],
"1x50x80x4x64": [
8,
800,
25
],
"1x25x40x4x64": [
8,
640,
20
],
"1x64x64x4x64": [
8,
512,
16
],
"64x56x56x5x64": [
8,
560,
14
],
"64x28x28x5x64": [
8,
560,
14
],
"64x14x14x5x64": [
8,
560,
14
],
"64x7x7x5x64": [
8,
280,
7
],
"1x200x320x5x64": [
8,
800,
20
],
"1x100x160x5x64": [
8,
800,
20
],
"1x50x80x5x64": [
8,
1000,
25
],
"1x25x40x5x64": [
8,
80,
2
],
"1x64x64x5x64": [
8,
320,
8
],
"64x56x56x6x64": [
8,
768,
16
],
"64x28x28x6x64": [
8,
768,
16
],
"64x14x14x6x64": [
8,
336,
7
],
"64x7x7x6x64": [
8,
336,
7
],
"1x200x320x6x64": [
8,
768,
16
],
"1x100x160x6x64": [
8,
480,
10
],
"1x50x80x6x64": [
16,
240,
10
],
"1x25x40x6x64": [
8,
240,
5
],
"1x64x64x6x64": [
8,
768,
16
],
"64x56x56x7x64": [
8,
896,
16
],
"64x28x28x7x64": [
8,
448,
8
],
"64x14x14x7x64": [
8,
392,
7
],
"64x7x7x7x64": [
8,
56,
1
],
"1x200x320x7x64": [
8,
896,
16
],
"1x100x160x7x64": [
8,
448,
8
],
"1x50x80x7x64": [
8,
448,
8
],
"1x25x40x7x64": [
8,
448,
8
],
"1x64x64x7x64": [
8,
448,
8
],
"64x56x56x8x64": [
8,
896,
14
],
"64x28x28x8x64": [
8,
896,
14
],
"64x14x14x8x64": [
8,
448,
7
],
"64x7x7x8x64": [
8,
64,
1
],
"1x200x320x8x64": [
8,
512,
8
],
"1x100x160x8x64": [
8,
512,
8
],
"1x50x80x8x64": [
8,
512,
8
],
"1x25x40x8x64": [
8,
512,
8
],
"1x64x64x8x64": [
8,
512,
8
]
}
BWDTABLE = {
"64x56x56x4x16": [
1,
256,
4
],
"64x56x56x5x16": [
1,
320,
4
],
"64x56x56x6x16": [
1,
192,
2
],
"64x56x56x7x16": [
1,
224,
2
],
"64x56x56x8x16": [
1,
256,
2
],
"64x56x56x4x32": [
1,
256,
2
],
"64x56x56x5x32": [
1,
160,
1
],
"64x56x56x6x32": [
1,
192,
1
],
"64x56x56x7x32": [
1,
224,
1
],
"64x56x56x8x32": [
1,
256,
1
],
"64x56x56x4x64": [
2,
512,
4
],
"64x56x56x5x64": [
2,
640,
4
],
"64x56x56x6x64": [
2,
384,
2
],
"64x56x56x7x64": [
2,
224,
1
],
"64x56x56x8x64": [
2,
1024,
4
],
"64x28x28x4x16": [
1,
128,
2
],
"64x28x28x5x16": [
1,
320,
4
],
"64x28x28x6x16": [
1,
96,
1
],
"64x28x28x7x16": [
1,
224,
2
],
"64x28x28x8x16": [
1,
128,
1
],
"64x28x28x4x32": [
1,
128,
1
],
"64x28x28x5x32": [
1,
320,
2
],
"64x28x28x6x32": [
1,
192,
1
],
"64x28x28x7x32": [
1,
224,
1
],
"64x28x28x8x32": [
1,
256,
1
],
"64x28x28x4x64": [
2,
512,
4
],
"64x28x28x5x64": [
2,
640,
4
],
"64x28x28x6x64": [
2,
384,
2
],
"64x28x28x7x64": [
2,
224,
1
],
"64x28x28x8x64": [
2,
512,
2
],
"64x14x14x4x16": [
1,
128,
2
],
"64x14x14x5x16": [
1,
320,
4
],
"64x14x14x6x16": [
1,
192,
2
],
"64x14x14x7x16": [
1,
224,
2
],
"64x14x14x8x16": [
1,
128,
1
],
"64x14x14x4x32": [
1,
256,
2
],
"64x14x14x5x32": [
1,
160,
1
],
"64x14x14x6x32": [
1,
192,
1
],
"64x14x14x7x32": [
1,
224,
1
],
"64x14x14x8x32": [
1,
256,
1
],
"64x14x14x4x64": [
2,
128,
1
],
"64x14x14x5x64": [
2,
160,
1
],
"64x14x14x6x64": [
2,
384,
2
],
"64x14x14x7x64": [
2,
224,
1
],
"64x14x14x8x64": [
2,
256,
1
],
"64x7x7x4x16": [
4,
784,
49
],
"64x7x7x5x16": [
2,
280,
7
],
"64x7x7x6x16": [
2,
48,
1
],
"64x7x7x7x16": [
2,
392,
7
],
"64x7x7x8x16": [
1,
128,
1
],
"64x7x7x4x32": [
1,
128,
1
],
"64x7x7x5x32": [
1,
160,
1
],
"64x7x7x6x32": [
2,
96,
1
],
"64x7x7x7x32": [
2,
112,
1
],
"64x7x7x8x32": [
2,
128,
1
],
"64x7x7x4x64": [
2,
896,
7
],
"64x7x7x5x64": [
2,
160,
1
],
"64x7x7x6x64": [
2,
192,
1
],
"64x7x7x7x64": [
2,
224,
1
],
"64x7x7x8x64": [
2,
256,
1
],
"1x200x320x4x16": [
1,
320,
5
],
"1x200x320x5x16": [
1,
320,
4
],
"1x200x320x6x16": [
1,
96,
1
],
"1x200x320x7x16": [
1,
224,
2
],
"1x200x320x8x16": [
1,
640,
5
],
"1x200x320x4x32": [
1,
128,
1
],
"1x200x320x5x32": [
1,
320,
2
],
"1x200x320x6x32": [
1,
384,
2
],
"1x200x320x7x32": [
1,
224,
1
],
"1x200x320x8x32": [
1,
256,
1
],
"1x200x320x4x64": [
2,
640,
5
],
"1x200x320x5x64": [
2,
800,
5
],
"1x200x320x6x64": [
2,
768,
4
],
"1x200x320x7x64": [
2,
448,
2
],
"1x200x320x8x64": [
2,
1024,
4
],
"1x100x160x4x16": [
1,
320,
5
],
"1x100x160x5x16": [
1,
640,
8
],
"1x100x160x6x16": [
1,
96,
1
],
"1x100x160x7x16": [
1,
224,
2
],
"1x100x160x8x16": [
1,
640,
5
],
"1x100x160x4x32": [
1,
256,
2
],
"1x100x160x5x32": [
1,
160,
1
],
"1x100x160x6x32": [
1,
384,
2
],
"1x100x160x7x32": [
1,
224,
1
],
"1x100x160x8x32": [
1,
512,
2
],
"1x100x160x4x64": [
2,
128,
1
],
"1x100x160x5x64": [
2,
160,
1
],
"1x100x160x6x64": [
2,
384,
2
],
"1x100x160x7x64": [
2,
448,
2
],
"1x100x160x8x64": [
2,
512,
2
],
"1x50x80x4x16": [
1,
320,
5
],
"1x50x80x5x16": [
1,
320,
4
],
"1x50x80x6x16": [
1,
96,
1
],
"1x50x80x7x16": [
1,
112,
1
],
"1x50x80x8x16": [
1,
512,
4
],
"1x50x80x4x32": [
1,
128,
1
],
"1x50x80x5x32": [
1,
320,
2
],
"1x50x80x6x32": [
1,
384,
2
],
"1x50x80x7x32": [
1,
224,
1
],
"1x50x80x8x32": [
1,
256,
1
],
"1x50x80x4x64": [
2,
256,
2
],
"1x50x80x5x64": [
2,
640,
4
],
"1x50x80x6x64": [
2,
768,
4
],
"1x50x80x7x64": [
2,
448,
2
],
"1x50x80x8x64": [
2,
1024,
4
],
"1x25x40x4x16": [
1,
320,
5
],
"1x25x40x5x16": [
2,
400,
10
],
"1x25x40x6x16": [
1,
192,
2
],
"1x25x40x7x16": [
4,
224,
8
],
"1x25x40x8x16": [
4,
160,
5
],
"1x25x40x4x32": [
2,
128,
2
],
"1x25x40x5x32": [
1,
320,
2
],
"1x25x40x6x32": [
2,
96,
1
],
"1x25x40x7x32": [
2,
112,
1
],
"1x25x40x8x32": [
2,
640,
5
],
"1x25x40x4x64": [
2,
128,
1
],
"1x25x40x5x64": [
2,
160,
1
],
"1x25x40x6x64": [
2,
192,
1
],
"1x25x40x7x64": [
2,
896,
4
],
"1x25x40x8x64": [
2,
512,
2
],
"1x64x64x4x16": [
1,
256,
4
],
"1x64x64x5x16": [
2,
40,
1
],
"1x64x64x6x16": [
1,
192,
2
],
"1x64x64x7x16": [
1,
224,
2
],
"1x64x64x8x16": [
1,
512,
4
],
"1x64x64x4x32": [
2,
64,
1
],
"1x64x64x5x32": [
1,
320,
2
],
"1x64x64x6x32": [
1,
192,
1
],
"1x64x64x7x32": [
1,
224,
1
],
"1x64x64x8x32": [
1,
256,
1
],
"1x64x64x4x64": [
2,
512,
4
],
"1x64x64x5x64": [
2,
640,
4
],
"1x64x64x6x64": [
2,
192,
1
],
"1x64x64x7x64": [
2,
224,
1
],
"1x64x64x8x64": [
2,
256,
1
]
}
\ No newline at end of file
python search_dcnv4_bwd_engine.py > res_bwd.txt
python find_best.py --input res_bwd.txt --output table_bwd.py
\ No newline at end of file
from __future__ import absolute_import
from __future__ import print_function
from __future__ import division
import time
import math
import torch
import torch.nn as nn
import math
from torch.autograd import gradcheck
import pandas as pd
from easydict import EasyDict as edict
import argparse
from torch.cuda import Event
from functions.dcnv3_func import DCNv3Function, dcnv3_core_pytorch
from functions.dcnv4_func import DCNv4Function
torch.set_printoptions(threshold=10000)
torch.manual_seed(3)
#@torch.no_grad()
def speed_test(func, args, inputs, name='Unknown'):
tic = Event(enable_timing=True)
toc = Event(enable_timing=True)
# warmup
for i in range(args.warmup_num):
func(*inputs)
total_time = 0
tic.record()
for i in range(args.test_num):
o = func(*inputs)
torch.cuda.synchronize()
toc.record()
avg_time = tic.elapsed_time(toc) / args.test_num
# print(
# f'>>> {name: <10} finished {args.test_num} running, avg_time: {avg_time:.6f} ms')
return avg_time
@torch.no_grad()
def test(N, H_in, W_in, M, D, spec=None):
Kh, Kw = 3, 3
remove_center = False
P = Kh * Kw - remove_center
offset_scale = 2.0
pad = 1
dilation = 1
stride = 1
H_out = (H_in + 2 * pad - (dilation * (Kh - 1) + 1)) // stride + 1
W_out = (W_in + 2 * pad - (dilation * (Kw - 1) + 1)) // stride + 1
input = torch.rand(N, H_in, W_in, M*D).cuda()
# print(input.shape)
offset = (torch.rand(N, H_out, W_out, M*P*2).cuda() * 2 - 1)*2
# offset = (torch.rand(N, H_out, W_out, M*P*2).cuda() * 2 - 1)*0
mask_origin = torch.rand(N, H_out, W_out, M, P).cuda() + 1e-5
mask_origin = mask_origin.half()
mask = mask_origin
# mask = torch.nn.functional.softmax(mask_origin, dim=-1)
offset_mask = torch.cat([offset.unflatten(-1, (M, P * 2)), mask_origin.detach()], dim=-1).flatten(-2)
im2col_step = 128
input = input.half()
offset = offset.half()
mask = mask.half()
offset_mask = offset_mask.half()
dcnv3_args = [
input,
offset,
mask,
Kh, Kw, stride, stride, Kh // 2, Kw // 2, dilation, dilation, M, D, offset_scale,
im2col_step, remove_center,
]
output_pytorch = DCNv3Function.apply(*dcnv3_args)
input1 = input.detach()
def pad(om):
padded_zero = int(math.ceil(om.shape[3]/8)*8) - om.shape[3]
padded = torch.zeros(om.shape[0], om.shape[1], om.shape[2], padded_zero).to(om)
return torch.cat([om, padded], dim=-1)
dcnv4_args = [
input1, pad(offset_mask),
Kh, Kw, stride, stride, Kh // 2, Kw // 2, dilation, dilation, M, D, offset_scale,
im2col_step, remove_center,
spec[0], spec[1], 2, None
# 8, 512, 2, 256
]
output_flash_cuda = DCNv4Function.apply(*dcnv4_args)
fwdok = torch.allclose(output_flash_cuda, output_pytorch, rtol=1e-2, atol=1e-3)
max_abs_err = (output_flash_cuda - output_pytorch).abs().max()
max_rel_err = ((output_flash_cuda - output_pytorch).abs() /
(output_pytorch.abs()+ 1e-3)).max()
# print('>>> forward half')
# print(f'* {fwdok} check_forward_equal_with_pytorch_float: max_abs_err {max_abs_err:.2e} max_rel_err {max_rel_err:.2e}')
if not fwdok:
print(f"Wrong: {N}x{H_in}x{W_in}x{M}x{D} \t {spec[0]}/{spec[1]}({spec[2]})")
return
# assert(fwdok)
test_args = edict({'warmup_num': 10000, 'test_num': 10000})
exp_time_dcnv4 = speed_test(DCNv4Function.apply, test_args, dcnv4_args, name='exp')
torch.cuda.synchronize()
print(f"{N}x{H_in}x{W_in}x{M}x{D} \t {spec[0]}/{spec[1]}({spec[2]}): {exp_time_dcnv4}")
if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument("--n", type=int)
parser.add_argument("--h", type=int)
parser.add_argument("--w", type=int)
parser.add_argument("--g", type=int)
parser.add_argument("--c", type=int)
parser.add_argument("--dstride", type=int)
parser.add_argument("--blockthread", type=int)
parser.add_argument("--multiplier", type=int)
args = parser.parse_args()
test(args.n, args.h, args.w, args.g, args.c, (args.dstride, args.blockthread, args.multiplier))
# --------------------------------------------------------
# DCNv4
# Copyright (c) 2024 OpenGVLab
# Licensed under The MIT License [see LICENSE for details]
# --------------------------------------------------------
from __future__ import absolute_import
from __future__ import print_function
from __future__ import division
import time
import torch
import torch.nn as nn
import math
from torch.autograd import gradcheck
import pandas as pd
from easydict import EasyDict as edict
import argparse
from torch.cuda import Event
from functions import DCNv4Function, DCNv3Function
torch.set_printoptions(threshold=10000)
torch.manual_seed(3)
def speed_test_backward(func, args, inputs, name='Unknown'):
# warmup
# for i in range(args.warmup_num):
# o = func(*inputs)
# o.sum().backward()
total_time = 0
len_input = len(inputs)
for i in range(args.warmup_num + args.test_num):
tic = Event(enable_timing=True)
toc = Event(enable_timing=True)
inputs[0] = inputs[0].detach()
inputs[0].requires_grad = True
if len_input > 1 and isinstance(inputs[1], torch.Tensor):
inputs[1] = inputs[1].detach()
inputs[1].requires_grad = True
if len_input > 2 and isinstance(inputs[2], torch.Tensor):
inputs[2] = inputs[2].detach()
inputs[2].requires_grad = True
o = func(*inputs)
torch.cuda.synchronize()
tic.record()
o.sum().backward()
toc.record()
torch.cuda.synchronize()
_time = tic.elapsed_time(toc)
if i >= args.warmup_num:
total_time += _time
o = o.detach()
# toc.record()
# torch.cuda.synchronize()
avg_time = total_time / args.test_num
#print(
# f'>>> {name: <10} finished {args.test_num} running, avg_time: {avg_time:.6f} ms')
return avg_time
# @torch.no_grad()
def test(N=64, H_in=32, W_in=32, M=4, D=16, spec=None):
"""
64x56x56x128(G=4)
2 64: 3.66
- offset_mask collection write 3.4022
- offset_mask collection 3.1968
"""
Kh, Kw = 3, 3
remove_center = False
P = Kh * Kw - remove_center
offset_scale = 2.0
pad = 1
dilation = 1
stride = 1
H_out = (H_in + 2 * pad - (dilation * (Kh - 1) + 1)) // stride + 1
W_out = (W_in + 2 * pad - (dilation * (Kw - 1) + 1)) // stride + 1
additions = [None, None, spec[0], spec[1], False]
input = torch.rand(N, H_in, W_in, M*D).cuda() * 10
#offset = torch.rand(N, H_out, W_out, M*P*2).cuda() * 0
offset = (torch.rand(N, H_out, W_out, M*P*2).cuda() * 2 - 1)*2
mask_origin = torch.rand(N, H_out, W_out, M, P).cuda() + 1e-5
mask_origin = mask_origin.half()
mask_origin.requires_grad = True
# offset_mask = torch.cat([offset.unflatten(-1, (M, P, 2)), mask_origin.detach().unsqueeze(-1)], dim=-1).flatten(-3)
# mask /= mask.sum(-1, keepdim=True)
# mask = torch.nn.functional.softmax(mask_origin, dim=-1, dtype=torch.float32)
mask = mask_origin
# mask = mask.reshape(N, H_out, W_out, M*P)
# offset_mask = torch.cat([offset.unflatten(-1, (M, P, 2)), mask.detach().unsqueeze(-1)], dim=-1).flatten(-3)
offset_mask = torch.cat([offset.detach().unflatten(-1, (M, P * 2)), mask_origin.detach()], dim=-1).flatten(-2)
im2col_step = 128
input = input.half()
offset = offset.half()
mask = mask.half()
input.requires_grad = True
offset.requires_grad = True
# mask.requires_grad = True
output_pytorch = DCNv3Function.apply(
input,
offset,
mask,
Kh, Kw, stride, stride, Kh // 2, Kw // 2, dilation, dilation, M, D, offset_scale,
im2col_step, remove_center)#.detach().cpu()
(output_pytorch.sum()/10).backward()
def pad(om):
padded_zero = int(math.ceil(om.shape[3]/8)*8) - om.shape[3]
padded = torch.zeros(om.shape[0], om.shape[1], om.shape[2], padded_zero).to(om)
return torch.cat([om, padded], dim=-1)
# value_offset_mask = input.detach()
input1 = input.detach()
input1.requires_grad = True
offset_mask = offset_mask.half()
offset_mask.requires_grad = True
# offset_mask1.requires_grad = True
torch.cuda.profiler.cudart().cudaProfilerStart()
output_flash_cuda = DCNv4Function.apply(
input1, offset_mask,
Kh, Kw, stride, stride, Kh // 2, Kw // 2, dilation, dilation, M, D, offset_scale,
im2col_step, remove_center, *additions)#.detach().cpu()
(output_flash_cuda.sum()/10).backward()
torch.cuda.profiler.cudart().cudaProfilerStop()
input_grad = input.grad
input2_grad = input1.grad
bwdok = torch.allclose(input_grad.float(), input2_grad.float(), rtol=1e-2, atol=1e-3)
rel_err = (input_grad.abs() - input2_grad.abs())/(input_grad.abs()+1e-3)
offset_grad1 = offset.grad
offset_grad2 = offset_mask.grad.reshape(N, H_out, W_out, M, P*3)[..., :P*2].reshape(N, H_out, W_out, M*P*2)
bwdok2 = torch.allclose(offset_grad1.float(), offset_grad2.float(), rtol=1e-2, atol=1e-3)
rel_err = (offset_grad1 - offset_grad2).abs() / (offset_grad1.abs()+1e-3)
mask_grad1 = mask_origin.grad
mask_grad2 = offset_mask.grad.reshape(N, H_out, W_out, M, P*3)[..., P*2:].reshape(N, H_out, W_out, M, P)
bwdok3 = torch.allclose(mask_grad1, mask_grad2, rtol=1e-2, atol=1e-3)
rel_err = (mask_grad1 - mask_grad2).abs() / (mask_grad1.abs()+1e-3)
fwdok = torch.allclose(output_flash_cuda, output_pytorch, rtol=1e-2, atol=1e-3)
max_abs_err = (output_flash_cuda - output_pytorch).abs().max()
max_rel_err = ((output_flash_cuda - output_pytorch).abs() /
(output_pytorch.abs()+ 1e-3)).max()
if not (bwdok and bwdok2 and bwdok3):
print(f"Wrong: {N}x{H_in}x{W_in}x{M}x{D} \t {spec[0]}/{spec[1]}({spec[2]})")
return
# fn_args = [
# input,
# offset,
# mask,
# Kh, Kw, stride, stride, Kh // 2, Kw // 2, dilation, dilation, M, D, offset_scale,
# im2col_step, remove_center
# ]
flash_dcn_fn_args = [
input1,
offset_mask,
Kh, Kw, stride, stride, Kh // 2, Kw // 2, dilation, dilation, M, D, offset_scale,
im2col_step, remove_center, *additions
]
test_args = edict({'warmup_num': 1000, 'test_num': 1000})
try:
exp_time = speed_test_backward(DCNv4Function.apply, test_args, flash_dcn_fn_args, name='exp')
except:
print(f"Wrong: {N}x{H_in}x{W_in}x{M}x{D} \t {spec[0]}/{spec[1]}({spec[2]})")
return
torch.cuda.synchronize()
print(f"{N}x{H_in}x{W_in}x{M}x{D} \t {spec[0]}/{spec[1]}({spec[2]}): {exp_time}")
if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument("--n", type=int)
parser.add_argument("--h", type=int)
parser.add_argument("--w", type=int)
parser.add_argument("--g", type=int)
parser.add_argument("--c", type=int)
parser.add_argument("--dstride", type=int)
parser.add_argument("--blockthread", type=int)
parser.add_argument("--multiplier", type=int)
args = parser.parse_args()
test(args.n, args.h, args.w, args.g, args.c, (args.dstride, args.blockthread, args.multiplier))
import os
def factors(N):
res = []
for i in range(1, N+1):
if N % i == 0:
res.append(i)
return res
if __name__ == '__main__':
BATCH=64
for N, Hin, Win in [(BATCH, 56, 56), (BATCH, 28, 28), (BATCH, 14, 14), (BATCH, 7, 7),
(1, 200, 320), (1, 100, 160), (1, 50, 80), (1, 25, 40), (1, 64, 64)]:
for group_channel in [16, 32, 64]:
for group in [4, 5, 6, 7, 8]:
for d_stride in [1, 2, 4]:
for m in factors(N*Hin*Win):
if m > 64:
break
block_thread = group * (group_channel//d_stride) * m
if block_thread > 1024:
break
cmd = f"python search_dcnv4_bwd.py --n {N} --h {Hin} --w {Win} --g {group} --c {group_channel} --dstride {d_stride} --blockthread {block_thread} --multiplier {m}"
os.system(cmd)
\ No newline at end of file
import os
def factors(N):
res = []
for i in range(1, N+1):
if N % i == 0:
res.append(i)
return res
if __name__ == '__main__':
BATCH=64
for group_channel in [16, 32, 64]:
for group in [4, 5, 6, 7, 8]:
for N, Hin, Win in [(BATCH, 56, 56), (BATCH, 28, 28), (BATCH, 14, 14), (BATCH, 7, 7),
(1, 200, 320), (1, 100, 160), (1, 50, 80), (1, 25, 40), (1, 64, 64)]:
for d_stride in [2, 4, 8, 16]:
for m in factors(N*Hin*Win):
if m > 64:
break
block_thread = group * (group_channel//d_stride) * m
if block_thread > 1024:
break
cmd = f"python search_dcnv4.py --n {N} --h {Hin} --w {Win} --g {group} --c {group_channel} --dstride {d_stride} --blockthread {block_thread} --multiplier {m}"
os.system(cmd)
\ No newline at end of file
python search_dcnv4_engine.py > res.txt
python find_best.py --input res.txt --output table.py
\ No newline at end of file
# --------------------------------------------------------
# DCNv4
# Copyright (c) 2024 OpenGVLab
# Licensed under The MIT License [see LICENSE for details]
# --------------------------------------------------------
from __future__ import absolute_import
from __future__ import print_function
from __future__ import division
import time
import torch
import torch.nn as nn
import math
from torch.autograd import gradcheck
import pandas as pd
from easydict import EasyDict as edict
from torch.cuda import Event
# from functions.dcnv3_func import DCNv3Function, dcnv3_core_pytorch
from functions.dcnv4_func import DCNv4Function
torch.set_printoptions(threshold=10000)
H_in, W_in = 56, 56
N, M, D = 64, 4, 32
# H_in, W_in = 28, 28
# N, M, D = 64, 8, 32
# H_in, W_in = 14, 14
# N, M, D = 64, 16, 32
# H_in, W_in = 7, 7
# N, M, D = 64, 32, 32
# H_in, W_in = 8, 8
# N, M, D = 128, 4, 16
Kh, Kw = 3, 3
remove_center = False
P = Kh * Kw - remove_center
offset_scale = 2.0
pad = 1
dilation = 1
stride = 1
H_out = (H_in + 2 * pad - (dilation * (Kh - 1) + 1)) // stride + 1
W_out = (W_in + 2 * pad - (dilation * (Kw - 1) + 1)) // stride + 1
torch.manual_seed(3)
#@torch.no_grad()
def speed_test(func, args, inputs, name='Unknown'):
tic = Event(enable_timing=True)
toc = Event(enable_timing=True)
# warmup
for i in range(args.warmup_num):
func(*inputs)
total_time = 0
tic.record()
for i in range(args.test_num):
o = func(*inputs)
torch.cuda.synchronize()
toc.record()
avg_time = tic.elapsed_time(toc) / args.test_num
print(
f'>>> {name: <10} finished {args.test_num} running, avg_time: {avg_time:.6f} ms')
return avg_time
@torch.no_grad()
def check_forward_equal_with_pytorch_half():
input = torch.rand(N, H_in, W_in, M*D).cuda()
print(input.shape)
offset = (torch.rand(N, H_out, W_out, M*P*2).cuda() * 2 - 1)*10
# offset = (torch.rand(N, H_out, W_out, M*P*2).cuda() * 2 - 1)*0
mask_origin = torch.rand(N, H_out, W_out, M, P).cuda() + 1e-5
mask_origin = mask_origin.half()
mask = mask_origin
# mask = torch.nn.functional.softmax(mask_origin, dim=-1)
offset_mask = torch.cat([offset.unflatten(-1, (M, P * 2)), mask_origin.detach()], dim=-1).flatten(-2)
im2col_step = 128
input = input.half()
offset = offset.half()
mask = mask.half()
offset_mask = offset_mask.half()
input1 = input.detach()
def pad(om):
padded_zero = int(math.ceil(om.shape[3]/8)*8) - om.shape[3]
padded = torch.zeros(om.shape[0], om.shape[1], om.shape[2], padded_zero).to(om)
return torch.cat([om, padded], dim=-1)
dcnv4_args = [
input1, pad(offset_mask),
Kh, Kw, stride, stride, Kh // 2, Kw // 2, dilation, dilation, M, D, offset_scale,
im2col_step, remove_center, 8, 512, 2, 256, True, True,
]
output_flash_cuda = DCNv4Function.apply(*dcnv4_args)
print(f"test success")
# fwdok = torch.allclose(output_flash_cuda, output_pytorch, rtol=1e-2, atol=1e-3)
# max_abs_err = (output_flash_cuda - output_pytorch).abs().max()
# max_rel_err = ((output_flash_cuda - output_pytorch).abs() /
# (output_pytorch.abs()+ 1e-3)).max()
# print('>>> forward half')
# print(f'* {fwdok} check_forward_equal_with_pytorch_float: max_abs_err {max_abs_err:.2e} max_rel_err {max_rel_err:.2e}')
# assert(fwdok)
# test_args = edict({'warmup_num': 1000, 'test_num': 1000})
# exp_time_dcnv4 = speed_test(DCNv4Function.apply, test_args, dcnv4_args, name='exp')
# torch.cuda.synchronize()
# results = [{}]
# results[0]['dcnv3_time'] = exp_time_dcnv3
# results[0]['dcnv4_time'] = exp_time_dcnv4
# columns = list(results[0].keys())
# outputs = pd.DataFrame(results, columns=columns)
# with pd.option_context(
# 'display.max_rows', None, 'display.max_columns', None,
# 'display.max_colwidth', None, 'display.width', None,
# 'display.precision', 4, ):
# print(outputs)
if __name__ == '__main__':
check_forward_equal_with_pytorch_half()
# --------------------------------------------------------
# DCNv4
# Copyright (c) 2024 OpenGVLab
# Licensed under The MIT License [see LICENSE for details]
# --------------------------------------------------------
from __future__ import absolute_import
from __future__ import print_function
from __future__ import division
import time
import torch
import torch.nn as nn
import math
from torch.autograd import gradcheck
import pandas as pd
from easydict import EasyDict as edict
from torch.cuda import Event
from functions import DCNv4Function, DCNv3Function
torch.set_printoptions(threshold=10000)
H_in, W_in = 56, 56
N, M, D = 64, 4, 32
# H_in, W_in = 28, 28
# N, M, D = 64, 16, 16
# H_in, W_in = 14, 14
# N, M, D = 64, 32, 16
# H_in, W_in = 7, 7
# N, M, D = 64, 64, 16
# H_in, W_in = 8, 8
# N, M, D = 128, 4, 16
Kh, Kw = 3, 3
remove_center = False
P = Kh * Kw - remove_center
offset_scale = 2.0
pad = 1
dilation = 1
stride = 1
H_out = (H_in + 2 * pad - (dilation * (Kh - 1) + 1)) // stride + 1
W_out = (W_in + 2 * pad - (dilation * (Kw - 1) + 1)) // stride + 1
torch.manual_seed(3)
def speed_test_backward(func, args, inputs, name='Unknown'):
# warmup
# for i in range(args.warmup_num):
# o = func(*inputs)
# o.sum().backward()
total_time = 0
len_input = len(inputs)
for i in range(args.warmup_num + args.test_num):
tic = Event(enable_timing=True)
toc = Event(enable_timing=True)
inputs[0] = inputs[0].detach()
inputs[0].requires_grad = True
if len_input > 1 and isinstance(inputs[1], torch.Tensor):
inputs[1] = inputs[1].detach()
inputs[1].requires_grad = True
if len_input > 2 and isinstance(inputs[2], torch.Tensor):
inputs[2] = inputs[2].detach()
inputs[2].requires_grad = True
o = func(*inputs)
torch.cuda.synchronize()
tic.record()
o.sum().backward()
toc.record()
torch.cuda.synchronize()
_time = tic.elapsed_time(toc)
if i >= args.warmup_num:
total_time += _time
o = o.detach()
# toc.record()
# torch.cuda.synchronize()
avg_time = total_time / args.test_num
#print(
# f'>>> {name: <10} finished {args.test_num} running, avg_time: {avg_time:.6f} ms')
return avg_time
# @torch.no_grad()
def check_forward_equal_with_pytorch_half():
"""
64x56x56x128(G=4)
2 64: 3.66
- offset_mask collection write 3.4022
- offset_mask collection 3.1968
"""
additions = [8, 128, 2, 256, False]
input = torch.rand(N, H_in, W_in, M*D).cuda() * 10
print(input.shape)
#offset = torch.rand(N, H_out, W_out, M*P*2).cuda() * 0
offset = (torch.rand(N, H_out, W_out, M*P*2).cuda() * 2 - 1)*2
mask_origin = torch.rand(N, H_out, W_out, M, P).cuda() + 1e-5
mask_origin = mask_origin.half()
mask_origin.requires_grad = True
# offset_mask = torch.cat([offset.unflatten(-1, (M, P, 2)), mask_origin.detach().unsqueeze(-1)], dim=-1).flatten(-3)
# mask /= mask.sum(-1, keepdim=True)
# mask = torch.nn.functional.softmax(mask_origin, dim=-1, dtype=torch.float32)
mask = mask_origin
# mask = mask.reshape(N, H_out, W_out, M*P)
# offset_mask = torch.cat([offset.unflatten(-1, (M, P, 2)), mask.detach().unsqueeze(-1)], dim=-1).flatten(-3)
offset_mask = torch.cat([offset.detach().unflatten(-1, (M, P * 2)), mask_origin.detach()], dim=-1).flatten(-2)
im2col_step = 128
input = input.half()
offset = offset.half()
mask = mask.half()
input.requires_grad = True
offset.requires_grad = True
# mask.requires_grad = True
output_pytorch = DCNv3Function.apply(
input,
offset,
mask,
Kh, Kw, stride, stride, Kh // 2, Kw // 2, dilation, dilation, M, D, offset_scale,
im2col_step, remove_center)#.detach().cpu()
(output_pytorch.sum()/10).backward()
def pad(om):
padded_zero = int(math.ceil(om.shape[3]/8)*8) - om.shape[3]
padded = torch.zeros(om.shape[0], om.shape[1], om.shape[2], padded_zero).to(om)
return torch.cat([om, padded], dim=-1)
# value_offset_mask = input.detach()
input1 = input.detach()
input1.requires_grad = True
offset_mask = offset_mask.half()
offset_mask.requires_grad = True
# offset_mask1.requires_grad = True
torch.cuda.profiler.cudart().cudaProfilerStart()
output_flash_cuda = DCNv4Function.apply(
input1, offset_mask,
Kh, Kw, stride, stride, Kh // 2, Kw // 2, dilation, dilation, M, D, offset_scale,
im2col_step, remove_center, *additions)#.detach().cpu()
(output_flash_cuda.sum()/10).backward()
torch.cuda.profiler.cudart().cudaProfilerStop()
input_grad = input.grad
input2_grad = input1.grad
bwdok = torch.allclose(input_grad.float(), input2_grad.float(), rtol=1e-2, atol=1e-3)
print("bwdok")
print(bwdok)
rel_err = (input_grad.abs() - input2_grad.abs())/(input_grad.abs()+1e-3)
print(rel_err.max())
offset_grad1 = offset.grad
offset_grad2 = offset_mask.grad.reshape(N, H_out, W_out, M, P*3)[..., :P*2].reshape(N, H_out, W_out, M*P*2)
# print(offset_grad1)
# print("====================")
# print(offset_grad2)
bwdok2 = torch.allclose(offset_grad1.float(), offset_grad2.float(), rtol=1e-2, atol=1e-3)
print("bwdok2")
print(bwdok2)
rel_err = (offset_grad1 - offset_grad2).abs() / (offset_grad1.abs()+1e-3)
print(rel_err.max())
mask_grad1 = mask_origin.grad
mask_grad2 = offset_mask.grad.reshape(N, H_out, W_out, M, P*3)[..., P*2:].reshape(N, H_out, W_out, M, P)
bwdok3 = torch.allclose(mask_grad1, mask_grad2, rtol=1e-2, atol=1e-3)
print("bwdok3")
print(bwdok3)
rel_err = (mask_grad1 - mask_grad2).abs() / (mask_grad1.abs()+1e-3)
print(rel_err.max())
fwdok = torch.allclose(output_flash_cuda, output_pytorch, rtol=1e-2, atol=1e-3)
max_abs_err = (output_flash_cuda - output_pytorch).abs().max()
max_rel_err = ((output_flash_cuda - output_pytorch).abs() /
(output_pytorch.abs()+ 1e-3)).max()
print('>>> forward half')
print(f'* {fwdok} check_forward_equal_with_pytorch_float: max_abs_err {max_abs_err:.2e} max_rel_err {max_rel_err:.2e}')
fn_args = [
input,
offset,
mask,
Kh, Kw, stride, stride, Kh // 2, Kw // 2, dilation, dilation, M, D, offset_scale,
im2col_step, remove_center
]
flash_dcn_fn_args = [
input1,
offset_mask,
Kh, Kw, stride, stride, Kh // 2, Kw // 2, dilation, dilation, M, D, offset_scale,
im2col_step, remove_center, *additions
]
test_args = edict({'warmup_num': 1000, 'test_num': 1000})
exp_time = speed_test_backward(
DCNv4Function.apply, test_args, flash_dcn_fn_args, name='exp')
exp_time_base = speed_test_backward(
DCNv3Function.apply, test_args, fn_args, name='exp')
results = [{}]
results[0]['time'] = exp_time
results[0]['time_base'] = exp_time_base
columns = list(results[0].keys())
outputs = pd.DataFrame(results, columns=columns)
with pd.option_context(
'display.max_rows', None, 'display.max_columns', None,
'display.max_colwidth', None, 'display.width', None,
'display.precision', 4, ):
print(outputs)
if __name__ == '__main__':
check_forward_equal_with_pytorch_half()
\ No newline at end of file
# ------------------------------------------------------------------------------------------------
# Deformable DETR
# Copyright (c) 2020 SenseTime. All Rights Reserved.
# Licensed under the Apache License, Version 2.0 [see LICENSE for details]
# ------------------------------------------------------------------------------------------------
# Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
# ------------------------------------------------------------------------------------------------
from __future__ import absolute_import
from __future__ import print_function
from __future__ import division
from easydict import EasyDict as edict
from torch.cuda import Event
import pandas as pd
import time
import torch
import torch.nn as nn
from torch.autograd import gradcheck
from functions import MSDeformAttnFunction, FlashDeformAttnFunction, ms_deform_attn_core_pytorch
# N, M, D = 1, 4, 8
# # Lq, L, P = 2, 2, 2
# # shapes = torch.as_tensor([(6, 4), (3, 2)], dtype=torch.long).cuda()
# Lq, L, P = 1, 2, 8
# shapes = torch.as_tensor([(8, 16), (4, 8)], dtype=torch.long).cuda()
# N, M, D = 1, 8, 32
# # Lq, L, P = 2, 2, 2
# # shapes = torch.as_tensor([(6, 4), (3, 2)], dtype=torch.long).cuda()
# Lq, L, P = 300, 4, 4
# # shapes = torch.as_tensor([(134, 151), (67, 76), (34, 38), (17, 19)], dtype=torch.long).cuda()
# # shapes = torch.as_tensor([(134, 151), (67, 76), (34, 38), (16, 16)], dtype=torch.long).cuda()
# # shapes = torch.as_tensor([(134, 151), (67, 76), (34, 38), (17, 19)], dtype=torch.long).cuda()
# # shapes = torch.as_tensor([(17, 19), (4, 4)], dtype=torch.long).cuda()
# shapes = torch.as_tensor([(100, 151), (50, 76), (25, 38), (13, 19)], dtype=torch.long).cuda()
# # shapes = torch.as_tensor([(110, 151)], dtype=torch.long).cuda()
# B:6
# H:232
# W:400
# G:5
# D: 16
# channels: 80
# kernel: 3 points = 3 * 3
# num_split = 45 = kernel *kernel * G
H = 256
W = 256
N, M, D = 1, 8, 32
Lq, L, P = 100*152, 4, 8
shapes = torch.Tensor([[100, 152], [ 50, 76], [ 25, 38], [ 13, 19]]).long().cuda()
# x = x.reshape([B, H*W, G, D + self.num_split * 3])
# shapes = torch.as_tensor([(H, W)], dtype=torch.long).cuda()
# shapes = torch.as_tensor([(H, W), (H // 2, W // 2)], dtype=torch.long).cuda()
# shapes = torch.as_tensor([(H, W), (H // 2, W // 2), (H // 4, W // 4), (H // 8, W // 8)], dtype=torch.long).cuda()
level_start_index = torch.cat((shapes.new_zeros((1,)), shapes.prod(1).cumsum(0)[:-1]))
S = sum([(H * W).item() for H, W in shapes])
print(S)
def get_reference_points(spatial_shapes, device):
reference_points_list = []
for lvl, (H_, W_) in enumerate(spatial_shapes):
ref_y, ref_x = torch.meshgrid(torch.linspace(0.5, H_ - 0.5, H_, dtype=torch.float32, device=device),
torch.linspace(0.5, W_ - 0.5, W_, dtype=torch.float32, device=device))
ref_y = ref_y.reshape(-1)[None] / (H_)
ref_x = ref_x.reshape(-1)[None] / (W_)
ref = torch.stack((ref_x, ref_y), -1)
reference_points_list.append(ref)
reference_points = torch.cat(reference_points_list, 1)
# reference_points = reference_points[:, :, None] * valid_ratios[:, None]
return reference_points
torch.manual_seed(3)
@torch.no_grad()
def speed_test(func, args, inputs, name='Unknown'):
tic = Event(enable_timing=True)
toc = Event(enable_timing=True)
# warmup
for i in range(args.warmup_num):
func(*inputs)
tic.record()
for i in range(args.test_num):
func(*inputs)
toc.record()
torch.cuda.synchronize()
avg_time = tic.elapsed_time(toc) / args.test_num
print(
f'>>> {name: <10} finished {args.test_num} running, avg_time: {avg_time:.6f} ms')
return avg_time
@torch.no_grad()
def check_forward_equal_with_pytorch_half():
value = torch.rand(N, S, M, D).cuda() * 0.01
# offset = (torch.rand(N, Lq, M, L, P, 2).cuda() * 2 - 1) / 10
sampling_locations = torch.rand(N, Lq, M, L, P, 2).cuda()
attention_weights = torch.rand(N, Lq, M, L, P).cuda() + 1e-5
sampling_loc_attn = torch.cat([sampling_locations.reshape(N, Lq, M, L*P*2), attention_weights.reshape(N, Lq, M, L*P)], dim=-1)
attention_weights = torch.nn.functional.softmax(attention_weights.flatten(-2, -1), dim=-1).unflatten(-1, (L, P))
im2col_step = 128
flash_fn_args = (
value.half(),
shapes,
level_start_index,
sampling_loc_attn.half(),
im2col_step,
P, 16
)
output_cuda = (
FlashDeformAttnFunction.apply(*flash_fn_args)
.detach()
.cpu()
).double()
fn_args = (
value,
shapes,
level_start_index,
sampling_locations,
attention_weights,
im2col_step,
)
output_pytorch = (
MSDeformAttnFunction.apply(*fn_args)
.detach().double()
.cpu()
)
max_abs_err = (output_cuda - output_pytorch).abs().max()
max_rel_err = ((output_cuda - output_pytorch).abs() / output_pytorch.abs()).max()
fwdok = torch.allclose(output_cuda, output_pytorch, rtol=1e-2, atol=1e-3)
print(
f"* {fwdok} check_forward_equal_with_pytorch_float: max_abs_err {max_abs_err:.2e} max_rel_err {max_rel_err:.2e}"
)
test_args = edict({'warmup_num': 1000, 'test_num': 1000})
exp_time_base = speed_test(
MSDeformAttnFunction.apply, test_args, fn_args, name='exp')
exp_time = speed_test(
FlashDeformAttnFunction.apply, test_args, flash_fn_args, name='exp')
results = [{}]
results[0]['time'] = exp_time
results[0]['time_base'] = exp_time_base
columns = list(results[0].keys())
outputs = pd.DataFrame(results, columns=columns)
with pd.option_context(
'display.max_rows', None, 'display.max_columns', None,
'display.max_colwidth', None, 'display.width', None,
'display.precision', 4, ):
print(outputs)
if __name__ == "__main__":
check_forward_equal_with_pytorch_half()
# ------------------------------------------------------------------------------------------------
# Deformable DETR
# Copyright (c) 2020 SenseTime. All Rights Reserved.
# Licensed under the Apache License, Version 2.0 [see LICENSE for details]
# ------------------------------------------------------------------------------------------------
# Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
# ------------------------------------------------------------------------------------------------
from __future__ import absolute_import
from __future__ import print_function
from __future__ import division
from easydict import EasyDict as edict
from torch.cuda import Event
import pandas as pd
import time
import torch
import torch.nn as nn
from torch.autograd import gradcheck
from functions import MSDeformAttnFunction, ms_deform_attn_core_pytorch, FlashDeformAttnFunction
H = 256
W = 256
N, M, D = 1, 8, 16
Lq, L, P = H * W, 1, 8
# x = x.reshape([B, H*W, G, D + self.num_split * 3])
shapes = torch.as_tensor([(H, W)], dtype=torch.long).cuda()
# shapes = torch.as_tensor([(H, W), (H // 2, W // 2)], dtype=torch.long).cuda()
# shapes = torch.as_tensor([(H, W), (H // 2, W // 2), (H // 4, W // 4), (H // 8, W // 8)], dtype=torch.long).cuda()
H = 256
W = 256
N, M, D = 1, 8, 32
Lq, L, P = 100*152, 4, 8
shapes = torch.Tensor([[100, 152], [ 50, 76], [ 25, 38], [ 13, 19]]).long().cuda()
level_start_index = torch.cat((shapes.new_zeros((1,)), shapes.prod(1).cumsum(0)[:-1]))
S = sum([(H * W).item() for H, W in shapes])
def get_reference_points(spatial_shapes, device):
reference_points_list = []
for lvl, (H_, W_) in enumerate(spatial_shapes):
ref_y, ref_x = torch.meshgrid(torch.linspace(0.5, H_ - 0.5, H_, dtype=torch.float32, device=device),
torch.linspace(0.5, W_ - 0.5, W_, dtype=torch.float32, device=device))
ref_y = ref_y.reshape(-1)[None] / (H_)
ref_x = ref_x.reshape(-1)[None] / (W_)
ref = torch.stack((ref_x, ref_y), -1)
reference_points_list.append(ref)
reference_points = torch.cat(reference_points_list, 1)
# reference_points = reference_points[:, :, None] * valid_ratios[:, None]
return reference_points
torch.manual_seed(3)
@torch.no_grad()
def speed_test(func, args, inputs, name='Unknown'):
tic = Event(enable_timing=True)
toc = Event(enable_timing=True)
# warmup
for i in range(args.warmup_num):
func(*inputs)
tic.record()
for i in range(args.test_num):
func(*inputs)
toc.record()
torch.cuda.synchronize()
avg_time = tic.elapsed_time(toc) / args.test_num
print(
f'>>> {name: <10} finished {args.test_num} running, avg_time: {avg_time:.6f} ms')
return avg_time
def check_forward_equal_with_pytorch_half():
value = torch.rand(N, S, M, D).cuda() * 0.01
offset = (torch.rand(N, Lq, M, L, P, 2).cuda() * 2 - 1) / 10
sampling_locations = torch.rand(N, Lq, M, L, P, 2).cuda()
attention_weights_origin = torch.rand(N, Lq, M, L, P).cuda() + 1e-5
attention_weights_origin.requires_grad = True
sampling_loc_attn = torch.cat([sampling_locations.detach().reshape(N, Lq, M, L*P*2), attention_weights_origin.detach().reshape(N, Lq, M, L*P)], dim=-1)
attention_weights = torch.nn.functional.softmax(attention_weights_origin.flatten(-2, -1), dim=-1).unflatten(-1, (L, P))
im2col_step = 128
value.requires_grad = True
sampling_loc_attn.requires_grad = True
output_cuda = (
FlashDeformAttnFunction.apply(
value.float(),
shapes,
level_start_index,
sampling_loc_attn.float(),
im2col_step,
)
)
(output_cuda.float().sum()/10).backward()
value1 = value.detach()
value1.requires_grad = True
sampling_locations.requires_grad = True
#attention_weights.requires_grad = True
output_pytorch = (
ms_deform_attn_core_pytorch(value1, shapes, sampling_locations, attention_weights)
)
(output_pytorch.sum()/10).backward()
max_abs_err = (output_cuda.float() - output_pytorch).abs().max()
max_rel_err = ((output_cuda.float() - output_pytorch).abs() / output_pytorch.abs()).max()
fwdok = torch.allclose(output_cuda.float(), output_pytorch, rtol=1e-2, atol=1e-3)
print(fwdok)
print(max_abs_err, max_rel_err)
#exit()
bwdok1 = torch.allclose(value.grad, value1.grad, rtol=1e-2, atol=1e-3)
print(bwdok1)
# rel_err = (sampling_locations.grad - sampling_loc_attn.grad[..., :L*P*2].reshape(*sampling_locations.shape)).abs()/(sampling_locations.grad.abs()+1e-3)
# print(rel_err.max())
locgrad1 = sampling_locations.grad
locgrad2 = sampling_loc_attn.grad[..., :L*P*2].reshape(*sampling_locations.shape)
bwdok2 = torch.allclose(locgrad1, locgrad2, rtol=1e-2, atol=1e-3)
print(bwdok2)
rel_err = (locgrad1 - locgrad2).abs()/(locgrad1.abs()+1e-3)
print(rel_err.max())
attngrad1 = attention_weights_origin.grad
attngrad2 = sampling_loc_attn.grad[..., L*P*2:].reshape(*attention_weights_origin.shape)
bwdok3 = torch.allclose(locgrad1, locgrad2, rtol=1e-2, atol=1e-3)
print(bwdok3)
rel_err = (attngrad1 - attngrad2).abs()/(attngrad1.abs()+1e-3)
print(rel_err.max())
exit()
#exit()
# pdb.set_trace()
max_abs_err = (output_cuda - output_pytorch).abs().max()
max_rel_err = ((output_cuda - output_pytorch).abs() / output_pytorch.abs()).max()
fwdok = torch.allclose(output_cuda, output_pytorch, rtol=1e-2, atol=1e-3)
print(
f"* {fwdok} check_forward_equal_with_pytorch_float: max_abs_err {max_abs_err:.2e} max_rel_err {max_rel_err:.2e}"
)
fn_args = (
value,
shapes,
level_start_index,
sampling_locations,
attention_weights,
im2col_step,
)
flash_dcn_fn_args = (
value.half(),
shapes,
level_start_index,
sampling_loc_attn.half(),
im2col_step,
)
test_args = edict({'warmup_num': 50, 'test_num': 100})
exp_time = speed_test(
FlashMSDeformAttnFunction.apply, test_args, flash_dcn_fn_args, name='exp')
exp_time_base = speed_test(
MSDeformAttnFunction.apply, test_args, fn_args, name='exp')
results = [{}]
results[0]['time'] = exp_time
results[0]['time_base'] = exp_time_base
columns = list(results[0].keys())
outputs = pd.DataFrame(results, columns=columns)
with pd.option_context(
'display.max_rows', None, 'display.max_columns', None,
'display.max_colwidth', None, 'display.width', None,
'display.precision', 4, ):
print(outputs)
if __name__ == "__main__":
check_forward_equal_with_pytorch_half()
\ No newline at end of file
# ------------------------------------------------------------------------------------------------
# Deformable Convolution v4
# Copyright (c) 2024 OpenGVLab
# Licensed under the Apache License, Version 2.0 [see LICENSE for details]
# ------------------------------------------------------------------------------------------------
# Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
# ------------------------------------------------------------------------------------------------
import os
import glob
import torch
# 导入打包相关库
from setuptools import find_packages, setup
from torch.utils.cpp_extension import CUDA_HOME, CppExtension, CUDAExtension
# 定义获取扩展的函数(保持原样,供非打包模式使用)
def get_extensions():
this_dir = os.path.dirname(os.path.abspath(__file__))
extensions_dir = os.path.join(this_dir, "src")
main_file = glob.glob(os.path.join(extensions_dir, "*.cpp"))
source_cpu = glob.glob(os.path.join(extensions_dir, "cpu", "*.cpp"))
source_cuda = glob.glob(os.path.join(extensions_dir, "cuda", "*.cu"))
sources = main_file + source_cpu
extension = CppExtension
extra_compile_args = {"cxx": []}
define_macros = []
if torch.cuda.is_available() and CUDA_HOME is not None:
extension = CUDAExtension
sources += source_cuda
define_macros += [("WITH_CUDA", None)]
extra_compile_args["nvcc"] = [
"-DCUDA_HAS_FP16=1",
"-D__CUDA_NO_HALF_OPERATORS__",
"-D__CUDA_NO_HALF_CONVERSIONS__",
"-D__CUDA_NO_HALF2_OPERATORS__",
"-O3",
]
else:
raise NotImplementedError('Cuda is not available')
sources = [os.path.join(extensions_dir, s) for s in sources]
include_dirs = [extensions_dir]
ext_modules = [
extension(
"DCNv4.ext", # 注意:这里保持原模块名,方便后面替换
sources,
include_dirs=include_dirs,
define_macros=define_macros,
extra_compile_args=extra_compile_args,
)
]
return ext_modules
# --- 核心修改逻辑 ---
# 检查是否是构建 Wheel 的模式
# 如果是构建 Wheel,我们不编译,而是将现有的 .so 作为包数据处理
# 注意:setuptools 打包扩展模块和打包数据文件的逻辑是冲突的,所以我们需要在构建 Wheel 时禁用 ext_modules
if __name__ == "__main__":
# 检查环境变量,决定是否跳过编译
# 你也可以直接写一个布尔值,或者检查某个文件是否存在
build_so = int(os.getenv("DCNv4_BUILD_SO", "0"))
# 准备参数
kwargs = {
"name": "DCNv4",
"version": "1.0.0.post2",
"author": "Yuwen Xiong, Feng Wang",
"url": "",
"description": "PyTorch Wrapper for CUDA Functions of DCNv4",
"packages": ['DCNv4', 'DCNv4/functions', 'DCNv4/modules'],
"package_data": {
"DCNv4": ["ext.so"], # 假设 ext.so 生成在 DCNv4 目录下
# "DCNv4": ["ext.cpython-310-x86_64-linux-gnu.so"], # 假设 ext.so 生成在 DCNv4 目录下
},
"cmdclass": {"build_ext": torch.utils.cpp_extension.BuildExtension},
# 确保生成正确的 .dist-info
"zip_safe": False,
# 添加以下参数来避免生成 .egg-info 在当前目录
"options": {
'egg_info': {
'egg_base': '/tmp' # 将 egg-info 生成到临时目录
}
},
}
if build_so:
# 正常开发模式,进行编译
kwargs["ext_modules"] = get_extensions()
else:
print("=== BUILD WHEEL MODE: Skipping compilation, using existing ext.so ===")
# 在构建 Wheel 时,不要传入 ext_modules
# 我们依赖 MANIFEST.in 或 package_data 将 .so 文件包含进去
# 但是 setuptools 的 bdist_wheel 默认会忽略 .so,所以我们需要确保 .so 在包目录里
# 这里我们不传入 ext_modules,而是依靠外部脚本或 MANIFEST.in
# 更简单的方法:直接在 setup 里不写 ext_modules,确保 .so 已经在 DCNv4/ 目录下
kwargs["ext_modules"] = [] # 强制不编译
setup(**kwargs)
\ No newline at end of file
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
/*!
**************************************************************************************************
* Deformable DETR
* Copyright (c) 2020 SenseTime. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 [see LICENSE for details]
**************************************************************************************************
* Modified from
*https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
**************************************************************************************************
*/
#pragma once
#include <torch/extension.h>
at::Tensor dcnv4_cuda_forward(
const at::Tensor &value,
const at::Tensor &p_offset,
const int kernel_h, const int kernel_w, const int stride_h,
const int stride_w, const int pad_h, const int pad_w, const int dilation_h,
const int dilation_w, const int group, const int group_channels,
const float offset_scale, const int im2col_step, const int remove_center,
const int d_stride, const int block_thread, const bool softmax);
std::vector<at::Tensor>
dcnv4_cuda_backward(
const at::Tensor &value,
const at::Tensor &p_offset,
const int kernel_h, const int kernel_w, const int stride_h,
const int stride_w, const int pad_h, const int pad_w, const int dilation_h,
const int dilation_w, const int group, const int group_channels,
const float offset_scale, const int im2col_step, const at::Tensor &grad_output,
const int remove_center, const int d_stride, const int block_thread,
const bool softmax);
\ No newline at end of file
This diff is collapsed.
This diff is collapsed.
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment