Commit 5b17e272 authored by wangkx1's avatar wangkx1
Browse files

init

parents
# --------------------------------------------------------
# InternImage
# Copyright (c) 2022 OpenGVLab
# Licensed under The MIT License [see LICENSE for details]
# --------------------------------------------------------
from __future__ import absolute_import
from __future__ import print_function
from __future__ import division
import torch
import math
import torch.nn.functional as F
from torch.autograd import Function
from torch.autograd.function import once_differentiable
from torch.cuda.amp import custom_bwd, custom_fwd
from .table import TABLE, BWDTABLE
from DCNv4 import ext
def factors(N):
    """Return every positive divisor of ``N`` in ascending order.

    Divisors are collected in pairs ``(i, N // i)``, so only candidates up
    to ``sqrt(N)`` are tried instead of every integer up to ``N``.  A
    non-positive ``N`` yields an empty list.
    """
    low, high = [], []
    i = 1
    while i * i <= N:
        if N % i == 0:
            low.append(i)
            partner = N // i
            if partner != i:  # do not list a perfect-square root twice
                high.append(partner)
        i += 1
    # ``high`` was gathered largest-first; reverse it so the result ascends.
    return low + high[::-1]
def findspec(B, H, W, G, C):
    """Return the forward launch spec ``(d_stride, n_thread)`` for a shape.

    Shapes already present in ``TABLE`` (keyed ``"BxHxWxGxC"``) use the
    stored values.  Otherwise ``d_stride`` is fixed at 8 and ``n_thread`` is
    built from the largest divisor ``m`` of ``B*H*W`` that satisfies
    ``m <= 64`` and ``m * G * C // d_stride <= 512`` (``m = 1`` when none
    qualifies).  The computed pair is stored back into ``TABLE`` so the
    search runs once per shape.
    """
    key = f"{B}x{H}x{W}x{G}x{C}"
    if key in TABLE:
        return TABLE[key][0], TABLE[key][1]
    d_stride = 8
    total = B * H * W
    multiplier = 1
    # The thread count grows with m, so walking candidates downwards and
    # stopping at the first divisor under the cap gives the largest valid m
    # without enumerating every divisor of B*H*W.
    for m in range(min(64, total), 0, -1):
        if total % m == 0 and (m * G * C // d_stride) <= 512:
            multiplier = m
            break
    n_thread = multiplier * G * C // d_stride
    TABLE[key] = (d_stride, n_thread)
    return d_stride, n_thread
def find_spec_bwd(B, H, W, G, C):
    """Return the backward launch spec ``(d_stride, n_thread)`` for a shape.

    Shapes present in ``BWDTABLE`` (keyed ``"BxHxWxGxC"``) use the stored
    values.  Otherwise ``d_stride`` is 2 when ``C >= 64`` and 1 otherwise,
    and ``n_thread`` is built from the largest divisor ``m`` of ``B*H*W``
    with ``m <= 64`` and ``m * G * C // d_stride <= 256`` (``m = 1`` when
    none qualifies).
    """
    key = f"{B}x{H}x{W}x{G}x{C}"
    if key in BWDTABLE:
        return BWDTABLE[key][0], BWDTABLE[key][1]
    d_stride = 2 if C >= 64 else 1
    total = B * H * W
    multiplier = 1
    # This fallback runs on every call for shapes missing from the table, so
    # only the (at most 64) usable candidates are examined, largest first.
    for m in range(min(64, total), 0, -1):
        if total % m == 0 and (m * G * C // d_stride) <= 256:
            multiplier = m
            break
    n_thread = multiplier * G * C // d_stride
    return d_stride, n_thread
class DCNv4Function(Function):
    """Autograd wrapper that dispatches to the DCNv4 kernels exposed by ``ext``."""

    @staticmethod
    @custom_fwd
    def forward(
            ctx, input, offset_mask,
            kernel_h, kernel_w, stride_h, stride_w,
            pad_h, pad_w, dilation_h, dilation_w,
            group, group_channels, offset_scale,
            im2col_step, remove_center):
        """Call ``ext.dcnv4_forward`` and record on ``ctx`` what backward needs.

        Launch specs for both passes are looked up from the first three
        dimensions of ``input`` together with ``group`` and
        ``group_channels``.
        """
        dim0, dim1, dim2 = input.shape[0], input.shape[1], input.shape[2]
        fwd_d_stride, fwd_block_thread = findspec(
            dim0, dim1, dim2, group, group_channels)
        bwd_d_stride, bwd_block_thread = find_spec_bwd(
            dim0, dim1, dim2, group, group_channels)

        # Non-tensor hyper-parameters travel to backward as ctx attributes.
        ctx.kernel_h, ctx.kernel_w = kernel_h, kernel_w
        ctx.stride_h, ctx.stride_w = stride_h, stride_w
        ctx.pad_h, ctx.pad_w = pad_h, pad_w
        ctx.dilation_h, ctx.dilation_w = dilation_h, dilation_w
        ctx.group, ctx.group_channels = group, group_channels
        ctx.offset_scale = offset_scale
        ctx.im2col_step = im2col_step
        ctx.remove_center = remove_center
        ctx.backward_d_stride = bwd_d_stride
        ctx.backward_block_thread = bwd_block_thread

        # NOTE(review): the final positional flag is always False; its meaning
        # is defined by the extension and is not visible here.
        output = ext.dcnv4_forward(
            input, offset_mask,
            kernel_h, kernel_w, stride_h, stride_w,
            pad_h, pad_w, dilation_h, dilation_w,
            group, group_channels, offset_scale,
            ctx.im2col_step,
            remove_center,
            fwd_d_stride,
            fwd_block_thread,
            False,
        )
        ctx.save_for_backward(input, offset_mask)
        return output

    @staticmethod
    @once_differentiable
    @custom_bwd
    def backward(ctx, grad_output):
        """Call ``ext.dcnv4_backward`` and return one gradient per forward argument."""
        input, offset_mask = ctx.saved_tensors
        grad_input, grad_offset_mask = ext.dcnv4_backward(
            input, offset_mask,
            ctx.kernel_h, ctx.kernel_w, ctx.stride_h, ctx.stride_w,
            ctx.pad_h, ctx.pad_w, ctx.dilation_h, ctx.dilation_w,
            ctx.group, ctx.group_channels, ctx.offset_scale, ctx.im2col_step,
            grad_output.contiguous(), ctx.remove_center,
            ctx.backward_d_stride, ctx.backward_block_thread,
            False,
        )
        # Two tensor gradients followed by None for each of the 13
        # non-tensor forward arguments.
        return (grad_input, grad_offset_mask) + (None,) * 13
# ------------------------------------------------------------------------------------------------
# Deformable DETR
# Copyright (c) 2020 SenseTime. All Rights Reserved.
# Licensed under the Apache License, Version 2.0 [see LICENSE for details]
# ------------------------------------------------------------------------------------------------
# Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
# ------------------------------------------------------------------------------------------------
from __future__ import absolute_import
from __future__ import print_function
from __future__ import division
import torch
import torch.nn.functional as F
from torch.autograd import Function
from torch.autograd.function import once_differentiable
import numpy as np
from DCNv4 import ext
# Lookup keyed by CUDA compute capability written as "major.minor".
# NOTE(review): the values look like a shared-memory budget in bytes — confirm
# the unit and where ``shm_size_cap`` is consumed (no use is visible here).
shm_size_dict = {
    "8.0": 163000,
    "8.6": 99000,
    "8.7": 163000,
    "8.9": 99000,
    "9.0": 227000,
    "7.5": 64000,
    "7.0": 96000,
}
# Capability of device 0; querying it requires a visible CUDA device at import time.
cuda_capability = f"{torch.cuda.get_device_properties(0).major}.{torch.cuda.get_device_properties(0).minor}"
# NOTE(review): this unconditional override discards the value detected above,
# so the lookup below always resolves to the "8.7" entry and the
# NotImplementedError branch can never trigger.  Presumably a workaround for a
# device whose capability is not listed — confirm it is intentional.
cuda_capability = "8.7"
if cuda_capability not in shm_size_dict:
    raise NotImplementedError
shm_size_cap = shm_size_dict[cuda_capability]
def factors(N):
    """Return every positive divisor of ``N`` in ascending order.

    Divisors are collected in pairs ``(i, N // i)``, so only candidates up
    to ``sqrt(N)`` are tried instead of every integer up to ``N``.  A
    non-positive ``N`` yields an empty list.
    """
    low, high = [], []
    i = 1
    while i * i <= N:
        if N % i == 0:
            low.append(i)
            partner = N // i
            if partner != i:  # do not list a perfect-square root twice
                high.append(partner)
        i += 1
    # ``high`` was gathered largest-first; reverse it so the result ascends.
    return low + high[::-1]
def findspec(B, Q, G, C):
    """Return the forward launch spec ``(d_stride, n_thread)``.

    ``d_stride`` is fixed at 8.  ``n_thread`` is ``m * G * C // d_stride``
    where ``m`` is the largest divisor of ``B * Q`` with ``m <= 64`` and
    ``m * G * C // d_stride <= 512`` (``m = 1`` when none qualifies).
    """
    d_stride = 8
    total = B * Q
    multiplier = 1
    # The thread count grows with m, so the first divisor found while walking
    # downwards from 64 that fits the cap is the largest valid one; this
    # avoids enumerating every divisor of B*Q on each forward call.
    for m in range(min(64, total), 0, -1):
        if total % m == 0 and (m * G * C // d_stride) <= 512:
            multiplier = m
            break
    n_thread = multiplier * G * C // d_stride
    return d_stride, n_thread
def findspec_bwd(B, Q, G, C):
    """Return the backward launch spec ``(d_stride, n_thread)``.

    ``d_stride`` is 2 when ``C >= 64`` and 1 otherwise.  ``n_thread`` is
    ``m * G * C // d_stride`` where ``m`` is the largest divisor of
    ``B * Q`` with ``m <= 64`` and ``m * G * C // d_stride <= 256``
    (``m = 1`` when none qualifies).
    """
    d_stride = 2 if C >= 64 else 1
    total = B * Q
    multiplier = 1
    # Walk candidates from large to small and stop at the first divisor that
    # fits the cap, rather than enumerating every divisor of B*Q per call.
    for m in range(min(64, total), 0, -1):
        if total % m == 0 and (m * G * C // d_stride) <= 256:
            multiplier = m
            break
    n_thread = multiplier * G * C // d_stride
    return d_stride, n_thread
class FlashDeformAttnFunction(Function):
    """Autograd wrapper around the flash deformable-attention kernels in ``ext``."""

    @staticmethod
    @torch.autocast("cuda", enabled=True, dtype=torch.float16)
    def forward(
            ctx, value, value_spatial_shapes, value_level_start_index,
            sampling_loc_attn, im2col_step, K=8):
        """Call ``ext.flash_deform_attn_forward`` and save state for backward.

        Launch specs are derived from ``value.shape[0]``,
        ``sampling_loc_attn.shape[1]``, ``value.shape[2]`` and
        ``value.shape[3]``.
        """
        ctx.im2col_step = im2col_step
        ctx.K = K

        spec_dims = (
            value.shape[0],
            sampling_loc_attn.shape[1],
            value.shape[2],
            value.shape[3],
        )
        fwd_d_stride, fwd_block_thread = findspec(*spec_dims)
        bwd_d_stride, bwd_block_thread = findspec_bwd(*spec_dims)
        ctx.d_stride_backward = bwd_d_stride
        ctx.blockthread_backward = bwd_block_thread

        output = ext.flash_deform_attn_forward(
            value, value_spatial_shapes, value_level_start_index,
            sampling_loc_attn, ctx.im2col_step, K,
            fwd_d_stride, fwd_block_thread,
        )
        ctx.save_for_backward(
            value, value_spatial_shapes, value_level_start_index,
            sampling_loc_attn)
        return output

    @staticmethod
    @once_differentiable
    def backward(ctx, grad_output):
        """Call ``ext.flash_deform_attn_backward``; only tensor inputs 0 and 3 get gradients."""
        (value, value_spatial_shapes,
         value_level_start_index, sampling_loc_attn) = ctx.saved_tensors
        grads = ext.flash_deform_attn_backward(
            value, value_spatial_shapes, value_level_start_index,
            sampling_loc_attn, grad_output.contiguous(),
            ctx.im2col_step, ctx.K,
            ctx.d_stride_backward, ctx.blockthread_backward,
        )
        grad_value, grad_sampling_loc_attn = grads
        # One slot per forward argument, in declaration order.
        return grad_value, None, None, grad_sampling_loc_attn, None, None
# Forward launch specs keyed by "BxHxWxGxC".
# Each value is [d_stride, n_thread, multiplier]; every entry satisfies
# n_thread == G * (C // d_stride) * multiplier.
TABLE = {
    # G=4, C=16
    "64x56x56x4x16": [8, 448, 56],
    "64x28x28x4x16": [8, 448, 56],
    "64x14x14x4x16": [8, 32, 4],
    "64x7x7x4x16": [8, 56, 7],
    "1x200x320x4x16": [8, 32, 4],
    "1x100x160x4x16": [8, 32, 4],
    "1x50x80x4x16": [4, 512, 32],
    "1x25x40x4x16": [4, 320, 20],
    "1x64x64x4x16": [8, 512, 64],
    # G=5, C=16
    "64x56x56x5x16": [8, 490, 49],
    "64x28x28x5x16": [8, 490, 49],
    "64x14x14x5x16": [8, 280, 28],
    "64x7x7x5x16": [4, 140, 7],
    "1x200x320x5x16": [4, 400, 20],
    "1x100x160x5x16": [4, 400, 20],
    "1x50x80x5x16": [8, 500, 50],
    "1x25x40x5x16": [8, 20, 2],
    "1x64x64x5x16": [8, 320, 32],
    # G=6, C=16
    "64x56x56x6x16": [8, 768, 64],
    "64x28x28x6x16": [8, 672, 56],
    "64x14x14x6x16": [8, 336, 28],
    "64x7x7x6x16": [8, 84, 7],
    "1x200x320x6x16": [4, 600, 25],
    "1x100x160x6x16": [4, 600, 25],
    "1x50x80x6x16": [8, 600, 50],
    "1x25x40x6x16": [2, 240, 5],
    "1x64x64x6x16": [8, 384, 32],
    # G=7, C=16
    "64x56x56x7x16": [8, 896, 64],
    "64x28x28x7x16": [8, 686, 49],
    "64x14x14x7x16": [8, 392, 28],
    "64x7x7x7x16": [8, 686, 49],
    "1x200x320x7x16": [8, 700, 50],
    "1x100x160x7x16": [8, 700, 50],
    "1x50x80x7x16": [8, 700, 50],
    "1x25x40x7x16": [8, 70, 5],
    "1x64x64x7x16": [8, 448, 32],
    # G=8, C=16
    "64x56x56x8x16": [8, 448, 28],
    "64x28x28x8x16": [8, 448, 28],
    "64x14x14x8x16": [8, 448, 28],
    "64x7x7x8x16": [8, 784, 49],
    "1x200x320x8x16": [8, 800, 50],
    "1x100x160x8x16": [4, 640, 20],
    "1x50x80x8x16": [8, 800, 50],
    "1x25x40x8x16": [4, 64, 2],
    "1x64x64x8x16": [8, 256, 16],
    # G=4, C=32
    "64x56x56x4x32": [8, 448, 28],
    "64x28x28x4x32": [8, 448, 28],
    "64x14x14x4x32": [8, 448, 28],
    "64x7x7x4x32": [8, 112, 7],
    "1x200x320x4x32": [8, 512, 32],
    "1x100x160x4x32": [8, 800, 50],
    "1x50x80x4x32": [8, 800, 50],
    "1x25x40x4x32": [4, 128, 4],
    "1x64x64x4x32": [8, 128, 8],
    # G=5, C=32
    "64x56x56x5x32": [8, 560, 28],
    "64x28x28x5x32": [8, 560, 28],
    "64x14x14x5x32": [8, 560, 28],
    "64x7x7x5x32": [8, 980, 49],
    "1x200x320x5x32": [8, 500, 25],
    "1x100x160x5x32": [8, 800, 40],
    "1x50x80x5x32": [8, 1000, 50],
    "1x25x40x5x32": [4, 200, 5],
    "1x64x64x5x32": [8, 640, 32],
    # G=6, C=32
    "64x56x56x6x32": [8, 336, 14],
    "64x28x28x6x32": [8, 336, 14],
    "64x14x14x6x32": [8, 336, 14],
    "64x7x7x6x32": [16, 588, 49],
    "1x200x320x6x32": [8, 480, 20],
    "1x100x160x6x32": [8, 480, 20],
    "1x50x80x6x32": [16, 600, 50],
    "1x25x40x6x32": [8, 96, 4],
    "1x64x64x6x32": [8, 768, 32],
    # G=7, C=32
    "64x56x56x7x32": [8, 448, 16],
    "64x28x28x7x32": [8, 448, 16],
    "64x14x14x7x32": [8, 196, 7],
    "64x7x7x7x32": [8, 28, 1],
    "1x200x320x7x32": [8, 448, 16],
    "1x100x160x7x32": [8, 448, 16],
    "1x50x80x7x32": [8, 700, 25],
    "1x25x40x7x32": [8, 56, 2],
    "1x64x64x7x32": [8, 896, 32],
    # G=8, C=32
    "64x56x56x8x32": [8, 448, 14],
    "64x28x28x8x32": [8, 448, 14],
    "64x14x14x8x32": [8, 448, 14],
    "64x7x7x8x32": [8, 32, 1],
    "1x200x320x8x32": [8, 512, 16],
    "1x100x160x8x32": [8, 800, 25],
    "1x50x80x8x32": [8, 800, 25],
    "1x25x40x8x32": [4, 512, 8],
    "1x64x64x8x32": [8, 32, 1],
    # G=4, C=64
    "64x56x56x4x64": [8, 448, 14],
    "64x28x28x4x64": [8, 448, 14],
    "64x14x14x4x64": [8, 448, 14],
    "64x7x7x4x64": [8, 32, 1],
    "1x200x320x4x64": [8, 512, 16],
    "1x100x160x4x64": [8, 512, 16],
    "1x50x80x4x64": [8, 800, 25],
    "1x25x40x4x64": [8, 640, 20],
    "1x64x64x4x64": [8, 512, 16],
    # G=5, C=64
    "64x56x56x5x64": [8, 560, 14],
    "64x28x28x5x64": [8, 560, 14],
    "64x14x14x5x64": [8, 560, 14],
    "64x7x7x5x64": [8, 280, 7],
    "1x200x320x5x64": [8, 800, 20],
    "1x100x160x5x64": [8, 800, 20],
    "1x50x80x5x64": [8, 1000, 25],
    "1x25x40x5x64": [8, 80, 2],
    "1x64x64x5x64": [8, 320, 8],
    # G=6, C=64
    "64x56x56x6x64": [8, 768, 16],
    "64x28x28x6x64": [8, 768, 16],
    "64x14x14x6x64": [8, 336, 7],
    "64x7x7x6x64": [8, 336, 7],
    "1x200x320x6x64": [8, 768, 16],
    "1x100x160x6x64": [8, 480, 10],
    "1x50x80x6x64": [16, 240, 10],
    "1x25x40x6x64": [8, 240, 5],
    "1x64x64x6x64": [8, 768, 16],
    # G=7, C=64
    "64x56x56x7x64": [8, 896, 16],
    "64x28x28x7x64": [8, 448, 8],
    "64x14x14x7x64": [8, 392, 7],
    "64x7x7x7x64": [8, 56, 1],
    "1x200x320x7x64": [8, 896, 16],
    "1x100x160x7x64": [8, 448, 8],
    "1x50x80x7x64": [8, 448, 8],
    "1x25x40x7x64": [8, 448, 8],
    "1x64x64x7x64": [8, 448, 8],
    # G=8, C=64
    "64x56x56x8x64": [8, 896, 14],
    "64x28x28x8x64": [8, 896, 14],
    "64x14x14x8x64": [8, 448, 7],
    "64x7x7x8x64": [8, 64, 1],
    "1x200x320x8x64": [8, 512, 8],
    "1x100x160x8x64": [8, 512, 8],
    "1x50x80x8x64": [8, 512, 8],
    "1x25x40x8x64": [8, 512, 8],
    "1x64x64x8x64": [8, 512, 8],
}
# Backward launch specs keyed by "BxHxWxGxC".
# Each value is [d_stride, n_thread, multiplier]; every entry satisfies
# n_thread == G * (C // d_stride) * multiplier.
BWDTABLE = {
    # 64x56x56
    "64x56x56x4x16": [1, 256, 4],
    "64x56x56x5x16": [1, 320, 4],
    "64x56x56x6x16": [1, 192, 2],
    "64x56x56x7x16": [1, 224, 2],
    "64x56x56x8x16": [1, 256, 2],
    "64x56x56x4x32": [1, 256, 2],
    "64x56x56x5x32": [1, 160, 1],
    "64x56x56x6x32": [1, 192, 1],
    "64x56x56x7x32": [1, 224, 1],
    "64x56x56x8x32": [1, 256, 1],
    "64x56x56x4x64": [2, 512, 4],
    "64x56x56x5x64": [2, 640, 4],
    "64x56x56x6x64": [2, 384, 2],
    "64x56x56x7x64": [2, 224, 1],
    "64x56x56x8x64": [2, 1024, 4],
    # 64x28x28
    "64x28x28x4x16": [1, 128, 2],
    "64x28x28x5x16": [1, 320, 4],
    "64x28x28x6x16": [1, 96, 1],
    "64x28x28x7x16": [1, 224, 2],
    "64x28x28x8x16": [1, 128, 1],
    "64x28x28x4x32": [1, 128, 1],
    "64x28x28x5x32": [1, 320, 2],
    "64x28x28x6x32": [1, 192, 1],
    "64x28x28x7x32": [1, 224, 1],
    "64x28x28x8x32": [1, 256, 1],
    "64x28x28x4x64": [2, 512, 4],
    "64x28x28x5x64": [2, 640, 4],
    "64x28x28x6x64": [2, 384, 2],
    "64x28x28x7x64": [2, 224, 1],
    "64x28x28x8x64": [2, 512, 2],
    # 64x14x14
    "64x14x14x4x16": [1, 128, 2],
    "64x14x14x5x16": [1, 320, 4],
    "64x14x14x6x16": [1, 192, 2],
    "64x14x14x7x16": [1, 224, 2],
    "64x14x14x8x16": [1, 128, 1],
    "64x14x14x4x32": [1, 256, 2],
    "64x14x14x5x32": [1, 160, 1],
    "64x14x14x6x32": [1, 192, 1],
    "64x14x14x7x32": [1, 224, 1],
    "64x14x14x8x32": [1, 256, 1],
    "64x14x14x4x64": [2, 128, 1],
    "64x14x14x5x64": [2, 160, 1],
    "64x14x14x6x64": [2, 384, 2],
    "64x14x14x7x64": [2, 224, 1],
    "64x14x14x8x64": [2, 256, 1],
    # 64x7x7
    "64x7x7x4x16": [4, 784, 49],
    "64x7x7x5x16": [2, 280, 7],
    "64x7x7x6x16": [2, 48, 1],
    "64x7x7x7x16": [2, 392, 7],
    "64x7x7x8x16": [1, 128, 1],
    "64x7x7x4x32": [1, 128, 1],
    "64x7x7x5x32": [1, 160, 1],
    "64x7x7x6x32": [2, 96, 1],
    "64x7x7x7x32": [2, 112, 1],
    "64x7x7x8x32": [2, 128, 1],
    "64x7x7x4x64": [2, 896, 7],
    "64x7x7x5x64": [2, 160, 1],
    "64x7x7x6x64": [2, 192, 1],
    "64x7x7x7x64": [2, 224, 1],
    "64x7x7x8x64": [2, 256, 1],
    # 1x200x320
    "1x200x320x4x16": [1, 320, 5],
    "1x200x320x5x16": [1, 320, 4],
    "1x200x320x6x16": [1, 96, 1],
    "1x200x320x7x16": [1, 224, 2],
    "1x200x320x8x16": [1, 640, 5],
    "1x200x320x4x32": [1, 128, 1],
    "1x200x320x5x32": [1, 320, 2],
    "1x200x320x6x32": [1, 384, 2],
    "1x200x320x7x32": [1, 224, 1],
    "1x200x320x8x32": [1, 256, 1],
    "1x200x320x4x64": [2, 640, 5],
    "1x200x320x5x64": [2, 800, 5],
    "1x200x320x6x64": [2, 768, 4],
    "1x200x320x7x64": [2, 448, 2],
    "1x200x320x8x64": [2, 1024, 4],
    # 1x100x160
    "1x100x160x4x16": [1, 320, 5],
    "1x100x160x5x16": [1, 640, 8],
    "1x100x160x6x16": [1, 96, 1],
    "1x100x160x7x16": [1, 224, 2],
    "1x100x160x8x16": [1, 640, 5],
    "1x100x160x4x32": [1, 256, 2],
    "1x100x160x5x32": [1, 160, 1],
    "1x100x160x6x32": [1, 384, 2],
    "1x100x160x7x32": [1, 224, 1],
    "1x100x160x8x32": [1, 512, 2],
    "1x100x160x4x64": [2, 128, 1],
    "1x100x160x5x64": [2, 160, 1],
    "1x100x160x6x64": [2, 384, 2],
    "1x100x160x7x64": [2, 448, 2],
    "1x100x160x8x64": [2, 512, 2],
    # 1x50x80
    "1x50x80x4x16": [1, 320, 5],
    "1x50x80x5x16": [1, 320, 4],
    "1x50x80x6x16": [1, 96, 1],
    "1x50x80x7x16": [1, 112, 1],
    "1x50x80x8x16": [1, 512, 4],
    "1x50x80x4x32": [1, 128, 1],
    "1x50x80x5x32": [1, 320, 2],
    "1x50x80x6x32": [1, 384, 2],
    "1x50x80x7x32": [1, 224, 1],
    "1x50x80x8x32": [1, 256, 1],
    "1x50x80x4x64": [2, 256, 2],
    "1x50x80x5x64": [2, 640, 4],
    "1x50x80x6x64": [2, 768, 4],
    "1x50x80x7x64": [2, 448, 2],
    "1x50x80x8x64": [2, 1024, 4],
    # 1x25x40
    "1x25x40x4x16": [1, 320, 5],
    "1x25x40x5x16": [2, 400, 10],
    "1x25x40x6x16": [1, 192, 2],
    "1x25x40x7x16": [4, 224, 8],
    "1x25x40x8x16": [4, 160, 5],
    "1x25x40x4x32": [2, 128, 2],
    "1x25x40x5x32": [1, 320, 2],
    "1x25x40x6x32": [2, 96, 1],
    "1x25x40x7x32": [2, 112, 1],
    "1x25x40x8x32": [2, 640, 5],
    "1x25x40x4x64": [2, 128, 1],
    "1x25x40x5x64": [2, 160, 1],
    "1x25x40x6x64": [2, 192, 1],
    "1x25x40x7x64": [2, 896, 4],
    "1x25x40x8x64": [2, 512, 2],
    # 1x64x64
    "1x64x64x4x16": [1, 256, 4],
    "1x64x64x5x16": [2, 40, 1],
    "1x64x64x6x16": [1, 192, 2],
    "1x64x64x7x16": [1, 224, 2],
    "1x64x64x8x16": [1, 512, 4],
    "1x64x64x4x32": [2, 64, 1],
    "1x64x64x5x32": [1, 320, 2],
    "1x64x64x6x32": [1, 192, 1],
    "1x64x64x7x32": [1, 224, 1],
    "1x64x64x8x32": [1, 256, 1],
    "1x64x64x4x64": [2, 512, 4],
    "1x64x64x5x64": [2, 640, 4],
    "1x64x64x6x64": [2, 192, 1],
    "1x64x64x7x64": [2, 224, 1],
    "1x64x64x8x64": [2, 256, 1],
}
\ No newline at end of file
# Step 1: run the backward-spec sweep and capture its printed timings.
python search_dcnv4_bwd_engine.py > res_bwd.txt
# Step 2: turn the log into table_bwd.py (presumably the best spec per shape —
# find_best.py is not shown here).
python find_best.py --input res_bwd.txt --output table_bwd.py
\ No newline at end of file
from __future__ import absolute_import
from __future__ import print_function
from __future__ import division
import time
import math
import torch
import torch.nn as nn
import math
from torch.autograd import gradcheck
import pandas as pd
from easydict import EasyDict as edict
import argparse
from torch.cuda import Event
from functions.dcnv3_func import DCNv3Function, dcnv3_core_pytorch
from functions.dcnv4_func import DCNv4Function
# Raise the element count above which printed tensors are summarised.
torch.set_printoptions(threshold=10000)
# Fixed seed so every candidate spec is benchmarked on the same random data.
torch.manual_seed(3)
#@torch.no_grad()
def speed_test(func, args, inputs, name='Unknown'):
    """Return the mean time in milliseconds of one ``func(*inputs)`` call.

    Args:
        func: callable to benchmark.
        args: object exposing ``warmup_num`` (untimed calls made first) and
            ``test_num`` (timed calls; must be non-zero).
        inputs: positional arguments forwarded to ``func``.
        name: label kept for call compatibility; unused while the report
            line below stays commented out.
    """
    tic = Event(enable_timing=True)
    toc = Event(enable_timing=True)
    # warmup
    for _ in range(args.warmup_num):
        func(*inputs)
    tic.record()
    for _ in range(args.test_num):
        func(*inputs)
    toc.record()
    # Both events must have completed before elapsed_time() is read, so the
    # device is synchronised after the closing event is recorded.
    torch.cuda.synchronize()
    avg_time = tic.elapsed_time(toc) / args.test_num
    # print(
    # f'>>> {name: <10} finished {args.test_num} running, avg_time: {avg_time:.6f} ms')
    return avg_time
@torch.no_grad()
def test(N, H_in, W_in, M, D, spec=None):
    """Check and time the DCNv4 forward pass for one candidate launch spec.

    Random fp16 input, offsets and mask are pushed through ``DCNv3Function``
    (the reference) and ``DCNv4Function``.  If the outputs disagree
    (``rtol=1e-2``, ``atol=1e-3``) a ``Wrong: ...`` line is printed and the
    function returns; otherwise the mean forward time from ``speed_test`` is
    printed.

    Args:
        N, H_in, W_in: leading dimensions of the random input tensor.
        M, D: group count and channels per group (input has ``M * D`` channels).
        spec: sequence indexed as ``spec[0]``, ``spec[1]``, ``spec[2]``; the
            first two are passed to ``DCNv4Function`` and all three are echoed
            in the printed line.  The ``None`` default cannot be indexed, so a
            value must be supplied.
    """
    Kh, Kw = 3, 3
    remove_center = False
    P = Kh * Kw - remove_center
    offset_scale = 2.0
    # Integer padding; only used for H_out/W_out before the helper of the
    # same name defined further down rebinds ``pad``.
    pad = 1
    dilation = 1
    stride = 1
    H_out = (H_in + 2 * pad - (dilation * (Kh - 1) + 1)) // stride + 1
    W_out = (W_in + 2 * pad - (dilation * (Kw - 1) + 1)) // stride + 1
    input = torch.rand(N, H_in, W_in, M*D).cuda()
    # print(input.shape)
    # Offsets are drawn uniformly from [-2, 2).
    offset = (torch.rand(N, H_out, W_out, M*P*2).cuda() * 2 - 1)*2
    # offset = (torch.rand(N, H_out, W_out, M*P*2).cuda() * 2 - 1)*0
    mask_origin = torch.rand(N, H_out, W_out, M, P).cuda() + 1e-5
    mask_origin = mask_origin.half()
    mask = mask_origin
    # mask = torch.nn.functional.softmax(mask_origin, dim=-1)
    # Per group: the 2*P offset values followed by the P mask values, then flattened.
    offset_mask = torch.cat([offset.unflatten(-1, (M, P * 2)), mask_origin.detach()], dim=-1).flatten(-2)
    im2col_step = 128
    input = input.half()
    offset = offset.half()
    mask = mask.half()
    offset_mask = offset_mask.half()
    dcnv3_args = [
        input,
        offset,
        mask,
        Kh, Kw, stride, stride, Kh // 2, Kw // 2, dilation, dilation, M, D, offset_scale,
        im2col_step, remove_center,
    ]
    output_pytorch = DCNv3Function.apply(*dcnv3_args)
    input1 = input.detach()
    def pad(om):
        # Zero-pad the last dimension up to the next multiple of 8.
        padded_zero = int(math.ceil(om.shape[3]/8)*8) - om.shape[3]
        padded = torch.zeros(om.shape[0], om.shape[1], om.shape[2], padded_zero).to(om)
        return torch.cat([om, padded], dim=-1)
    # NOTE(review): four extra trailing arguments are passed after
    # remove_center; the DCNv4Function imported from functions.dcnv4_func must
    # accept them — confirm against that module.
    dcnv4_args = [
        input1, pad(offset_mask),
        Kh, Kw, stride, stride, Kh // 2, Kw // 2, dilation, dilation, M, D, offset_scale,
        im2col_step, remove_center,
        spec[0], spec[1], 2, None
        # 8, 512, 2, 256
    ]
    output_flash_cuda = DCNv4Function.apply(*dcnv4_args)
    fwdok = torch.allclose(output_flash_cuda, output_pytorch, rtol=1e-2, atol=1e-3)
    # The two error metrics below are only consumed by the commented-out report.
    max_abs_err = (output_flash_cuda - output_pytorch).abs().max()
    max_rel_err = ((output_flash_cuda - output_pytorch).abs() /
                   (output_pytorch.abs()+ 1e-3)).max()
    # print('>>> forward half')
    # print(f'* {fwdok} check_forward_equal_with_pytorch_float: max_abs_err {max_abs_err:.2e} max_rel_err {max_rel_err:.2e}')
    if not fwdok:
        # Report the failing candidate instead of raising so a sweep can continue.
        print(f"Wrong: {N}x{H_in}x{W_in}x{M}x{D} \t {spec[0]}/{spec[1]}({spec[2]})")
        return
    # assert(fwdok)
    test_args = edict({'warmup_num': 10000, 'test_num': 10000})
    exp_time_dcnv4 = speed_test(DCNv4Function.apply, test_args, dcnv4_args, name='exp')
    torch.cuda.synchronize()
    print(f"{N}x{H_in}x{W_in}x{M}x{D} \t {spec[0]}/{spec[1]}({spec[2]}): {exp_time_dcnv4}")
if __name__ == '__main__':
    # CLI entry point: one invocation benchmarks a single shape/spec candidate.
    parser = argparse.ArgumentParser()
    # --n/--h/--w map to test()'s N, H_in, W_in; --g and --c to M and D.
    parser.add_argument("--n", type=int)
    parser.add_argument("--h", type=int)
    parser.add_argument("--w", type=int)
    parser.add_argument("--g", type=int)
    parser.add_argument("--c", type=int)
    # The remaining three are bundled, in this order, into the ``spec`` tuple.
    parser.add_argument("--dstride", type=int)
    parser.add_argument("--blockthread", type=int)
    parser.add_argument("--multiplier", type=int)
    args = parser.parse_args()
    test(args.n, args.h, args.w, args.g, args.c, (args.dstride, args.blockthread, args.multiplier))
# --------------------------------------------------------
# DCNv4
# Copyright (c) 2024 OpenGVLab
# Licensed under The MIT License [see LICENSE for details]
# --------------------------------------------------------
from __future__ import absolute_import
from __future__ import print_function
from __future__ import division
import time
import torch
import torch.nn as nn
import math
from torch.autograd import gradcheck
import pandas as pd
from easydict import EasyDict as edict
import argparse
from torch.cuda import Event
from functions import DCNv4Function, DCNv3Function
# Raise the element count above which printed tensors are summarised.
torch.set_printoptions(threshold=10000)
# Fixed seed so every candidate spec is benchmarked on the same random data.
torch.manual_seed(3)
def speed_test_backward(func, args, inputs, name='Unknown'):
    """Return the mean backward time in milliseconds of ``func(*inputs)``.

    Runs ``args.warmup_num + args.test_num`` rounds; only the last
    ``args.test_num`` rounds count towards the average.  ``inputs`` is
    modified in place: slot 0, and slots 1 and 2 when they hold tensors, are
    replaced each round by detached copies with ``requires_grad`` enabled.
    """
    elapsed_ms = 0
    n_inputs = len(inputs)
    n_rounds = args.warmup_num + args.test_num
    for step in range(n_rounds):
        start = Event(enable_timing=True)
        stop = Event(enable_timing=True)

        # Slot 0 is always re-leafed; slots 1 and 2 only when they are tensors.
        inputs[0] = inputs[0].detach()
        inputs[0].requires_grad = True
        for slot in (1, 2):
            if n_inputs > slot and isinstance(inputs[slot], torch.Tensor):
                inputs[slot] = inputs[slot].detach()
                inputs[slot].requires_grad = True

        out = func(*inputs)
        # Drain the forward work so only the backward pass falls between the events.
        torch.cuda.synchronize()
        start.record()
        out.sum().backward()
        stop.record()
        torch.cuda.synchronize()
        round_ms = start.elapsed_time(stop)
        if step >= args.warmup_num:
            elapsed_ms += round_ms
        out = out.detach()
    avg_time = elapsed_ms / args.test_num
    return avg_time
# @torch.no_grad()
def test(N=64, H_in=32, W_in=32, M=4, D=16, spec=None):
    """Validate and time the DCNv4 backward pass for one candidate launch spec.

    Gradients of ``DCNv4Function`` with respect to the input, the offsets and
    the mask are compared against those of ``DCNv3Function`` on the same
    random fp16 data (``rtol=1e-2``, ``atol=1e-3``).  A mismatch, or an
    exception raised while timing, prints a ``Wrong: ...`` line and returns;
    otherwise the mean backward time from ``speed_test_backward`` is printed.

    Args:
        N, H_in, W_in: leading dimensions of the random input tensor.
        M, D: group count and channels per group (input has ``M * D`` channels).
        spec: sequence indexed as ``spec[0]``, ``spec[1]``, ``spec[2]``; the
            first two are forwarded to ``DCNv4Function`` and all three are
            echoed in the printed line.  The ``None`` default cannot be
            indexed, so a value must be supplied.
    """
    Kh, Kw = 3, 3
    remove_center = False
    P = Kh * Kw - remove_center
    offset_scale = 2.0
    pad = 1
    dilation = 1
    stride = 1
    H_out = (H_in + 2 * pad - (dilation * (Kh - 1) + 1)) // stride + 1
    W_out = (W_in + 2 * pad - (dilation * (Kw - 1) + 1)) // stride + 1
    # NOTE(review): extra trailing arguments for DCNv4Function; the signature
    # exported by ``functions`` must accept them — confirm against that module.
    additions = [None, None, spec[0], spec[1], False]

    # Draw order (input, offset, mask) is kept fixed so the seeded data is reproducible.
    input = torch.rand(N, H_in, W_in, M*D).cuda() * 10
    offset = (torch.rand(N, H_out, W_out, M*P*2).cuda() * 2 - 1)*2
    mask_origin = torch.rand(N, H_out, W_out, M, P).cuda() + 1e-5
    mask_origin = mask_origin.half()
    mask_origin.requires_grad = True
    mask = mask_origin
    # Per group: the 2*P offset values followed by the P mask values, then flattened.
    offset_mask = torch.cat([offset.detach().unflatten(-1, (M, P * 2)), mask_origin.detach()], dim=-1).flatten(-2)
    im2col_step = 128
    input = input.half()
    offset = offset.half()
    mask = mask.half()
    input.requires_grad = True
    offset.requires_grad = True

    # Reference gradients.
    output_pytorch = DCNv3Function.apply(
        input,
        offset,
        mask,
        Kh, Kw, stride, stride, Kh // 2, Kw // 2, dilation, dilation, M, D, offset_scale,
        im2col_step, remove_center)
    (output_pytorch.sum()/10).backward()

    # Candidate gradients on independent leaf tensors.
    input1 = input.detach()
    input1.requires_grad = True
    offset_mask = offset_mask.half()
    offset_mask.requires_grad = True
    torch.cuda.profiler.cudart().cudaProfilerStart()
    output_flash_cuda = DCNv4Function.apply(
        input1, offset_mask,
        Kh, Kw, stride, stride, Kh // 2, Kw // 2, dilation, dilation, M, D, offset_scale,
        im2col_step, remove_center, *additions)
    (output_flash_cuda.sum()/10).backward()
    torch.cuda.profiler.cudart().cudaProfilerStop()

    input_grad = input.grad
    input2_grad = input1.grad
    bwdok = torch.allclose(input_grad.float(), input2_grad.float(), rtol=1e-2, atol=1e-3)
    # The fused gradient is split back into its offset part and its mask part.
    offset_grad1 = offset.grad
    offset_grad2 = offset_mask.grad.reshape(N, H_out, W_out, M, P*3)[..., :P*2].reshape(N, H_out, W_out, M*P*2)
    bwdok2 = torch.allclose(offset_grad1.float(), offset_grad2.float(), rtol=1e-2, atol=1e-3)
    mask_grad1 = mask_origin.grad
    mask_grad2 = offset_mask.grad.reshape(N, H_out, W_out, M, P*3)[..., P*2:].reshape(N, H_out, W_out, M, P)
    bwdok3 = torch.allclose(mask_grad1, mask_grad2, rtol=1e-2, atol=1e-3)
    if not (bwdok and bwdok2 and bwdok3):
        print(f"Wrong: {N}x{H_in}x{W_in}x{M}x{D} \t {spec[0]}/{spec[1]}({spec[2]})")
        return

    flash_dcn_fn_args = [
        input1,
        offset_mask,
        Kh, Kw, stride, stride, Kh // 2, Kw // 2, dilation, dilation, M, D, offset_scale,
        im2col_step, remove_center, *additions
    ]
    test_args = edict({'warmup_num': 1000, 'test_num': 1000})
    try:
        exp_time = speed_test_backward(DCNv4Function.apply, test_args, flash_dcn_fn_args, name='exp')
    except Exception:
        # A candidate that fails to run is reported like a wrong result, but
        # KeyboardInterrupt/SystemExit still propagate so a sweep can be stopped.
        print(f"Wrong: {N}x{H_in}x{W_in}x{M}x{D} \t {spec[0]}/{spec[1]}({spec[2]})")
        return
    torch.cuda.synchronize()
    print(f"{N}x{H_in}x{W_in}x{M}x{D} \t {spec[0]}/{spec[1]}({spec[2]}): {exp_time}")
if __name__ == '__main__':
    # CLI entry point: one invocation benchmarks a single shape/spec candidate.
    parser = argparse.ArgumentParser()
    # --n/--h/--w map to test()'s N, H_in, W_in; --g and --c to M and D.
    parser.add_argument("--n", type=int)
    parser.add_argument("--h", type=int)
    parser.add_argument("--w", type=int)
    parser.add_argument("--g", type=int)
    parser.add_argument("--c", type=int)
    # The remaining three are bundled, in this order, into the ``spec`` tuple.
    parser.add_argument("--dstride", type=int)
    parser.add_argument("--blockthread", type=int)
    parser.add_argument("--multiplier", type=int)
    args = parser.parse_args()
    test(args.n, args.h, args.w, args.g, args.c, (args.dstride, args.blockthread, args.multiplier))
import os
def factors(N):
    """Return every positive divisor of ``N`` in ascending order.

    Divisors are collected in pairs ``(i, N // i)``, so only candidates up
    to ``sqrt(N)`` are tried instead of every integer up to ``N``.  A
    non-positive ``N`` yields an empty list.
    """
    low, high = [], []
    i = 1
    while i * i <= N:
        if N % i == 0:
            low.append(i)
            partner = N // i
            if partner != i:  # do not list a perfect-square root twice
                high.append(partner)
        i += 1
    # ``high`` was gathered largest-first; reverse it so the result ascends.
    return low + high[::-1]
if __name__ == '__main__':
    # Sweep every (shape, channels-per-group, groups, d_stride, multiplier)
    # combination and benchmark each with search_dcnv4_bwd.py in a subprocess.
    BATCH = 64
    shapes = [
        (BATCH, 56, 56), (BATCH, 28, 28), (BATCH, 14, 14), (BATCH, 7, 7),
        (1, 200, 320), (1, 100, 160), (1, 50, 80), (1, 25, 40), (1, 64, 64),
    ]
    for N, Hin, Win in shapes:
        multipliers = factors(N*Hin*Win)
        for group_channel in (16, 32, 64):
            for group in (4, 5, 6, 7, 8):
                for d_stride in (1, 2, 4):
                    for m in multipliers:
                        block_thread = group * (group_channel//d_stride) * m
                        # Multipliers ascend, so once either limit is exceeded
                        # no later candidate can fit.
                        if m > 64 or block_thread > 1024:
                            break
                        cmd = f"python search_dcnv4_bwd.py --n {N} --h {Hin} --w {Win} --g {group} --c {group_channel} --dstride {d_stride} --blockthread {block_thread} --multiplier {m}"
                        os.system(cmd)
\ No newline at end of file
import os
def factors(N):
    """Return every positive divisor of ``N`` in ascending order.

    Divisors are collected in pairs ``(i, N // i)``, so only candidates up
    to ``sqrt(N)`` are tried instead of every integer up to ``N``.  A
    non-positive ``N`` yields an empty list.
    """
    low, high = [], []
    i = 1
    while i * i <= N:
        if N % i == 0:
            low.append(i)
            partner = N // i
            if partner != i:  # do not list a perfect-square root twice
                high.append(partner)
        i += 1
    # ``high`` was gathered largest-first; reverse it so the result ascends.
    return low + high[::-1]
if __name__ == '__main__':
    # Sweep every (channels-per-group, groups, shape, d_stride, multiplier)
    # combination and benchmark each with search_dcnv4.py in a subprocess.
    BATCH = 64
    shapes = [
        (BATCH, 56, 56), (BATCH, 28, 28), (BATCH, 14, 14), (BATCH, 7, 7),
        (1, 200, 320), (1, 100, 160), (1, 50, 80), (1, 25, 40), (1, 64, 64),
    ]
    for group_channel in (16, 32, 64):
        for group in (4, 5, 6, 7, 8):
            for N, Hin, Win in shapes:
                multipliers = factors(N*Hin*Win)
                for d_stride in (2, 4, 8, 16):
                    for m in multipliers:
                        block_thread = group * (group_channel//d_stride) * m
                        # Multipliers ascend, so once either limit is exceeded
                        # no later candidate can fit.
                        if m > 64 or block_thread > 1024:
                            break
                        cmd = f"python search_dcnv4.py --n {N} --h {Hin} --w {Win} --g {group} --c {group_channel} --dstride {d_stride} --blockthread {block_thread} --multiplier {m}"
                        os.system(cmd)
\ No newline at end of file
# Step 1: run the forward-spec sweep and capture its printed timings.
python search_dcnv4_engine.py > res.txt
# Step 2: turn the log into table.py (presumably the best spec per shape —
# find_best.py is not shown here).
python find_best.py --input res.txt --output table.py
\ No newline at end of file
# --------------------------------------------------------
# DCNv4
# Copyright (c) 2024 OpenGVLab
# Licensed under The MIT License [see LICENSE for details]
# --------------------------------------------------------
from __future__ import absolute_import
from __future__ import print_function
from __future__ import division
import time
import torch
import torch.nn as nn
import math
from torch.autograd import gradcheck
import pandas as pd
from easydict import EasyDict as edict
from torch.cuda import Event
# from functions.dcnv3_func import DCNv3Function, dcnv3_core_pytorch
from functions.dcnv4_func import DCNv4Function
# Raise the element count above which printed tensors are summarised.
torch.set_printoptions(threshold=10000)
# Problem size used by the check below; the commented pairs are alternative
# configurations to swap in by hand.
H_in, W_in = 56, 56
N, M, D = 64, 4, 32
# H_in, W_in = 28, 28
# N, M, D = 64, 8, 32
# H_in, W_in = 14, 14
# N, M, D = 64, 16, 32
# H_in, W_in = 7, 7
# N, M, D = 64, 32, 32
# H_in, W_in = 8, 8
# N, M, D = 128, 4, 16
# 3x3 kernel; P is the number of sampling points per group.
Kh, Kw = 3, 3
remove_center = False
P = Kh * Kw - remove_center
offset_scale = 2.0
pad = 1
dilation = 1
stride = 1
# Standard convolution output-size formula.
H_out = (H_in + 2 * pad - (dilation * (Kh - 1) + 1)) // stride + 1
W_out = (W_in + 2 * pad - (dilation * (Kw - 1) + 1)) // stride + 1
# Fixed seed so the random test data is reproducible.
torch.manual_seed(3)
#@torch.no_grad()
def speed_test(func, args, inputs, name='Unknown'):
    """Time ``func(*inputs)``, print a one-line report and return the mean in ms.

    Args:
        func: callable to benchmark.
        args: object exposing ``warmup_num`` (untimed calls made first) and
            ``test_num`` (timed calls; must be non-zero).
        inputs: positional arguments forwarded to ``func``.
        name: label shown in the printed report.
    """
    tic = Event(enable_timing=True)
    toc = Event(enable_timing=True)
    # warmup
    for _ in range(args.warmup_num):
        func(*inputs)
    tic.record()
    for _ in range(args.test_num):
        func(*inputs)
    toc.record()
    # Both events must have completed before elapsed_time() is read, so the
    # device is synchronised after the closing event is recorded.
    torch.cuda.synchronize()
    avg_time = tic.elapsed_time(toc) / args.test_num
    print(
        f'>>> {name: <10} finished {args.test_num} running, avg_time: {avg_time:.6f} ms')
    return avg_time
@torch.no_grad()
def check_forward_equal_with_pytorch_half():
    """Smoke-test the DCNv4 forward kernel on random fp16 data.

    Builds input, offsets and mask from the module-level problem size, runs
    ``DCNv4Function`` once and prints ``test success``.  The comparison
    against a reference output and the timing report are commented out, so
    the function only shows that the call completes.
    """
    input = torch.rand(N, H_in, W_in, M*D).cuda()
    print(input.shape)
    # Offsets are drawn uniformly from [-10, 10).
    offset = (torch.rand(N, H_out, W_out, M*P*2).cuda() * 2 - 1)*10
    # offset = (torch.rand(N, H_out, W_out, M*P*2).cuda() * 2 - 1)*0
    mask_origin = torch.rand(N, H_out, W_out, M, P).cuda() + 1e-5
    mask_origin = mask_origin.half()
    mask = mask_origin
    # mask = torch.nn.functional.softmax(mask_origin, dim=-1)
    # Per group: the 2*P offset values followed by the P mask values, then flattened.
    offset_mask = torch.cat([offset.unflatten(-1, (M, P * 2)), mask_origin.detach()], dim=-1).flatten(-2)
    im2col_step = 128
    input = input.half()
    offset = offset.half()
    mask = mask.half()
    offset_mask = offset_mask.half()
    input1 = input.detach()
    def pad(om):
        # Zero-pad the last dimension up to the next multiple of 8.
        padded_zero = int(math.ceil(om.shape[3]/8)*8) - om.shape[3]
        padded = torch.zeros(om.shape[0], om.shape[1], om.shape[2], padded_zero).to(om)
        return torch.cat([om, padded], dim=-1)
    # NOTE(review): six extra trailing arguments follow remove_center; the
    # DCNv4Function imported from functions.dcnv4_func must accept them —
    # confirm against that module.
    dcnv4_args = [
        input1, pad(offset_mask),
        Kh, Kw, stride, stride, Kh // 2, Kw // 2, dilation, dilation, M, D, offset_scale,
        im2col_step, remove_center, 8, 512, 2, 256, True, True,
    ]
    output_flash_cuda = DCNv4Function.apply(*dcnv4_args)
    print(f"test success")
    # Disabled: comparison against a reference output and timing report.
    # fwdok = torch.allclose(output_flash_cuda, output_pytorch, rtol=1e-2, atol=1e-3)
    # max_abs_err = (output_flash_cuda - output_pytorch).abs().max()
    # max_rel_err = ((output_flash_cuda - output_pytorch).abs() /
    # (output_pytorch.abs()+ 1e-3)).max()
    # print('>>> forward half')
    # print(f'* {fwdok} check_forward_equal_with_pytorch_float: max_abs_err {max_abs_err:.2e} max_rel_err {max_rel_err:.2e}')
    # assert(fwdok)
    # test_args = edict({'warmup_num': 1000, 'test_num': 1000})
    # exp_time_dcnv4 = speed_test(DCNv4Function.apply, test_args, dcnv4_args, name='exp')
    # torch.cuda.synchronize()
    # results = [{}]
    # results[0]['dcnv3_time'] = exp_time_dcnv3
    # results[0]['dcnv4_time'] = exp_time_dcnv4
    # columns = list(results[0].keys())
    # outputs = pd.DataFrame(results, columns=columns)
    # with pd.option_context(
    # 'display.max_rows', None, 'display.max_columns', None,
    # 'display.max_colwidth', None, 'display.width', None,
    # 'display.precision', 4, ):
    # print(outputs)
if __name__ == '__main__':
    # Run the forward smoke test when executed as a script.
    check_forward_equal_with_pytorch_half()
# --------------------------------------------------------
# DCNv4
# Copyright (c) 2024 OpenGVLab
# Licensed under The MIT License [see LICENSE for details]
# --------------------------------------------------------
from __future__ import absolute_import
from __future__ import print_function
from __future__ import division
import time
import torch
import torch.nn as nn
import math
from torch.autograd import gradcheck
import pandas as pd
from easydict import EasyDict as edict
from torch.cuda import Event
from functions import DCNv4Function, DCNv3Function
# Raise the element count above which printed tensors are summarised.
torch.set_printoptions(threshold=10000)
# Problem size used by the check below; the commented pairs are alternative
# configurations to swap in by hand.
H_in, W_in = 56, 56
N, M, D = 64, 4, 32
# H_in, W_in = 28, 28
# N, M, D = 64, 16, 16
# H_in, W_in = 14, 14
# N, M, D = 64, 32, 16
# H_in, W_in = 7, 7
# N, M, D = 64, 64, 16
# H_in, W_in = 8, 8
# N, M, D = 128, 4, 16
# 3x3 kernel; P is the number of sampling points per group.
Kh, Kw = 3, 3
remove_center = False
P = Kh * Kw - remove_center
offset_scale = 2.0
pad = 1
dilation = 1
stride = 1
# Standard convolution output-size formula.
H_out = (H_in + 2 * pad - (dilation * (Kh - 1) + 1)) // stride + 1
W_out = (W_in + 2 * pad - (dilation * (Kw - 1) + 1)) // stride + 1
# Fixed seed so the random test data is reproducible.
torch.manual_seed(3)
def speed_test_backward(func, args, inputs, name='Unknown'):
    """Measure the average backward-pass time of ``func(*inputs)``.

    Each round re-creates the leading tensor arguments as fresh leaves that
    require grad, runs the forward pass untimed, and then times only
    ``o.sum().backward()`` with a pair of CUDA events.

    Args:
        func: callable invoked as ``func(*inputs)``; must return a tensor.
        args: object exposing ``warmup_num`` and ``test_num``.
        inputs: mutable argument list.  ``inputs[0]`` is always replaced by a
            detached copy requiring grad; ``inputs[1]`` and ``inputs[2]`` get
            the same treatment only when they exist and are tensors.
        name: label for the run (currently unused; the report line is
            disabled).

    Returns:
        Mean elapsed time per timed round, as reported by
        ``Event.elapsed_time`` (milliseconds), over ``args.test_num`` rounds.
    """
    accumulated = 0
    n_inputs = len(inputs)
    n_rounds = args.warmup_num + args.test_num
    for round_idx in range(n_rounds):
        start_evt = Event(enable_timing=True)
        stop_evt = Event(enable_timing=True)

        # Cut the inputs loose from the previous round's graph.
        inputs[0] = inputs[0].detach()
        inputs[0].requires_grad = True
        for slot in (1, 2):
            if n_inputs > slot and isinstance(inputs[slot], torch.Tensor):
                inputs[slot] = inputs[slot].detach()
                inputs[slot].requires_grad = True

        result = func(*inputs)
        # Drain the forward work so it is not attributed to the backward pass.
        torch.cuda.synchronize()
        start_evt.record()
        result.sum().backward()
        stop_evt.record()
        torch.cuda.synchronize()
        round_time = start_evt.elapsed_time(stop_evt)

        # Warm-up rounds are executed but not counted.
        if round_idx >= args.warmup_num:
            accumulated += round_time
        result = result.detach()

    return accumulated / args.test_num
# @torch.no_grad()
def check_forward_equal_with_pytorch_half():
    """Check DCNv4 against DCNv3 in fp16 (outputs and gradients), then time backward.

    Random half-precision inputs are run through ``DCNv3Function`` (used as
    the reference) and ``DCNv4Function``.  The outputs and the gradients with
    respect to the input, the offsets and the mask are compared with
    ``torch.allclose(rtol=1e-2, atol=1e-3)``; each verdict and the largest
    relative error are printed.  Afterwards the backward pass of both
    functions is timed with ``speed_test_backward`` and printed as a one-row
    table (``time`` = DCNv4, ``time_base`` = DCNv3).

    Requires a CUDA device.  Returns nothing.

    Author's recorded measurements for 64x56x56x128 (G=4); units not stated:
        2 64: 3.66
        - offset_mask collection write 3.4022
        - offset_mask collection 3.1968
    """
    # Extra trailing arguments forwarded verbatim to DCNv4Function.apply.
    # NOTE(review): presumably forward/backward d_stride and block-thread
    # settings plus a softmax flag -- confirm against DCNv4Function.forward.
    additions = [8, 128, 2, 256, False]

    input = torch.rand(N, H_in, W_in, M * D).cuda() * 10
    print(input.shape)
    offset = (torch.rand(N, H_out, W_out, M * P * 2).cuda() * 2 - 1) * 2
    mask_origin = torch.rand(N, H_out, W_out, M, P).cuda() + 1e-5
    mask_origin = mask_origin.half()
    mask_origin.requires_grad = True
    # DCNv3 consumes the mask directly; it is the same leaf tensor, so its
    # gradient lands in ``mask_origin.grad``.
    mask = mask_origin

    # DCNv4 takes offsets and mask fused in one tensor: for every group the
    # 2*P offset values are followed by the P mask values.
    offset_mask = torch.cat(
        [offset.detach().unflatten(-1, (M, P * 2)), mask_origin.detach()],
        dim=-1,
    ).flatten(-2)

    im2col_step = 128
    input = input.half()
    offset = offset.half()
    input.requires_grad = True
    offset.requires_grad = True

    # Reference: DCNv3 forward + backward.
    output_pytorch = DCNv3Function.apply(
        input,
        offset,
        mask,
        Kh, Kw, stride, stride, Kh // 2, Kw // 2, dilation, dilation, M, D, offset_scale,
        im2col_step, remove_center)
    (output_pytorch.sum() / 10).backward()

    # Independent leaves for DCNv4 so both gradient sets can be compared.
    input1 = input.detach()
    input1.requires_grad = True
    offset_mask = offset_mask.half()
    offset_mask.requires_grad = True

    torch.cuda.profiler.cudart().cudaProfilerStart()
    output_flash_cuda = DCNv4Function.apply(
        input1, offset_mask,
        Kh, Kw, stride, stride, Kh // 2, Kw // 2, dilation, dilation, M, D, offset_scale,
        im2col_step, remove_center, *additions)
    (output_flash_cuda.sum() / 10).backward()
    torch.cuda.profiler.cudart().cudaProfilerStop()

    # --- gradient w.r.t. the input ---
    input_grad = input.grad
    input2_grad = input1.grad
    bwdok = torch.allclose(input_grad.float(), input2_grad.float(), rtol=1e-2, atol=1e-3)
    print("bwdok")
    print(bwdok)
    # Relative error of the difference.  The previous form subtracted the
    # absolute values, which yields a signed quantity and reports zero error
    # for gradients that differ only in sign.
    rel_err = (input_grad - input2_grad).abs() / (input_grad.abs() + 1e-3)
    print(rel_err.max())

    # --- gradient w.r.t. the offsets (first 2*P entries of each group) ---
    fused_grad = offset_mask.grad.reshape(N, H_out, W_out, M, P * 3)
    offset_grad1 = offset.grad
    offset_grad2 = fused_grad[..., :P * 2].reshape(N, H_out, W_out, M * P * 2)
    bwdok2 = torch.allclose(offset_grad1.float(), offset_grad2.float(), rtol=1e-2, atol=1e-3)
    print("bwdok2")
    print(bwdok2)
    rel_err = (offset_grad1 - offset_grad2).abs() / (offset_grad1.abs() + 1e-3)
    print(rel_err.max())

    # --- gradient w.r.t. the mask (last P entries of each group) ---
    mask_grad1 = mask_origin.grad
    mask_grad2 = fused_grad[..., P * 2:].reshape(N, H_out, W_out, M, P)
    bwdok3 = torch.allclose(mask_grad1, mask_grad2, rtol=1e-2, atol=1e-3)
    print("bwdok3")
    print(bwdok3)
    rel_err = (mask_grad1 - mask_grad2).abs() / (mask_grad1.abs() + 1e-3)
    print(rel_err.max())

    # --- forward outputs ---
    fwdok = torch.allclose(output_flash_cuda, output_pytorch, rtol=1e-2, atol=1e-3)
    max_abs_err = (output_flash_cuda - output_pytorch).abs().max()
    max_rel_err = ((output_flash_cuda - output_pytorch).abs() /
                   (output_pytorch.abs() + 1e-3)).max()
    print('>>> forward half')
    print(f'* {fwdok} check_forward_equal_with_pytorch_float: max_abs_err {max_abs_err:.2e} max_rel_err {max_rel_err:.2e}')

    # --- backward timing ---
    # Lists (not tuples): speed_test_backward replaces the tensor entries.
    fn_args = [
        input,
        offset,
        mask,
        Kh, Kw, stride, stride, Kh // 2, Kw // 2, dilation, dilation, M, D, offset_scale,
        im2col_step, remove_center
    ]
    flash_dcn_fn_args = [
        input1,
        offset_mask,
        Kh, Kw, stride, stride, Kh // 2, Kw // 2, dilation, dilation, M, D, offset_scale,
        im2col_step, remove_center, *additions
    ]
    test_args = edict({'warmup_num': 1000, 'test_num': 1000})
    exp_time = speed_test_backward(
        DCNv4Function.apply, test_args, flash_dcn_fn_args, name='exp')
    exp_time_base = speed_test_backward(
        DCNv3Function.apply, test_args, fn_args, name='exp')

    results = [{}]
    results[0]['time'] = exp_time
    results[0]['time_base'] = exp_time_base
    columns = list(results[0].keys())
    outputs = pd.DataFrame(results, columns=columns)
    with pd.option_context(
        'display.max_rows', None, 'display.max_columns', None,
        'display.max_colwidth', None, 'display.width', None,
        'display.precision', 4, ):
        print(outputs)
# Script entry point: call check_forward_equal_with_pytorch_half() only when
# this file is executed directly, not when it is imported.
if __name__ == '__main__':
    check_forward_equal_with_pytorch_half()
\ No newline at end of file
# ------------------------------------------------------------------------------------------------
# Deformable DETR
# Copyright (c) 2020 SenseTime. All Rights Reserved.
# Licensed under the Apache License, Version 2.0 [see LICENSE for details]
# ------------------------------------------------------------------------------------------------
# Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
# ------------------------------------------------------------------------------------------------
from __future__ import absolute_import
from __future__ import print_function
from __future__ import division
from easydict import EasyDict as edict
from torch.cuda import Event
import pandas as pd
import time
import torch
import torch.nn as nn
from torch.autograd import gradcheck
from functions import MSDeformAttnFunction, FlashDeformAttnFunction, ms_deform_attn_core_pytorch
# N, M, D = 1, 4, 8
# # Lq, L, P = 2, 2, 2
# # shapes = torch.as_tensor([(6, 4), (3, 2)], dtype=torch.long).cuda()
# Lq, L, P = 1, 2, 8
# shapes = torch.as_tensor([(8, 16), (4, 8)], dtype=torch.long).cuda()
# N, M, D = 1, 8, 32
# # Lq, L, P = 2, 2, 2
# # shapes = torch.as_tensor([(6, 4), (3, 2)], dtype=torch.long).cuda()
# Lq, L, P = 300, 4, 4
# # shapes = torch.as_tensor([(134, 151), (67, 76), (34, 38), (17, 19)], dtype=torch.long).cuda()
# # shapes = torch.as_tensor([(134, 151), (67, 76), (34, 38), (16, 16)], dtype=torch.long).cuda()
# # shapes = torch.as_tensor([(134, 151), (67, 76), (34, 38), (17, 19)], dtype=torch.long).cuda()
# # shapes = torch.as_tensor([(17, 19), (4, 4)], dtype=torch.long).cuda()
# shapes = torch.as_tensor([(100, 151), (50, 76), (25, 38), (13, 19)], dtype=torch.long).cuda()
# # shapes = torch.as_tensor([(110, 151)], dtype=torch.long).cuda()
# B:6
# H:232
# W:400
# G:5
# D: 16
# channels: 80
# kernel: 3 points = 3 * 3
# num_split = 45 = kernel *kernel * G
H = W = 256
# Batch size, number of heads/groups and channels per head used by the check.
N, M, D = 1, 8, 32
# Number of queries, feature levels and sampling points per level.
Lq, L, P = 100 * 152, 4, 8
# One (height, width) row per feature level, as a long tensor on the GPU.
shapes = torch.Tensor(
    [[100, 152], [50, 76], [25, 38], [13, 19]]
).long().cuda()
# x = x.reshape([B, H*W, G, D + self.num_split * 3])
# shapes = torch.as_tensor([(H, W)], dtype=torch.long).cuda()
# shapes = torch.as_tensor([(H, W), (H // 2, W // 2)], dtype=torch.long).cuda()
# shapes = torch.as_tensor([(H, W), (H // 2, W // 2), (H // 4, W // 4), (H // 8, W // 8)], dtype=torch.long).cuda()

# Offset of every level inside the flattened value tensor: exclusive running
# sum of the per-level areas, starting at zero.
_level_areas = shapes.prod(1)
level_start_index = torch.cat(
    (shapes.new_zeros((1,)), _level_areas.cumsum(0)[:-1])
)
del _level_areas
# Total number of spatial positions over all levels.
S = sum((lvl_h * lvl_w).item() for lvl_h, lvl_w in shapes)
print(S)
def get_reference_points(spatial_shapes, device):
    """Return the normalised pixel-centre coordinates of every feature level.

    For a level of size ``(H_, W_)`` the centres ``0.5 .. H_ - 0.5`` and
    ``0.5 .. W_ - 0.5`` are divided by ``H_`` and ``W_`` respectively, and
    the grid is flattened row by row.  Levels are concatenated in the order
    given.

    Args:
        spatial_shapes: iterable of ``(H_, W_)`` pairs, one per level.
        device: device on which the coordinates are created.

    Returns:
        float32 tensor of shape ``(1, sum(H_ * W_), 2)`` whose last axis is
        ordered ``(x, y)``.
    """
    per_level = []
    for H_, W_ in spatial_shapes:
        rows = torch.linspace(0.5, H_ - 0.5, H_, dtype=torch.float32, device=device)
        cols = torch.linspace(0.5, W_ - 0.5, W_, dtype=torch.float32, device=device)
        grid_y, grid_x = torch.meshgrid(rows, cols)
        norm_y = grid_y.reshape(-1)[None] / (H_)
        norm_x = grid_x.reshape(-1)[None] / (W_)
        per_level.append(torch.stack((norm_x, norm_y), -1))
    # reference_points = reference_points[:, :, None] * valid_ratios[:, None]
    return torch.cat(per_level, 1)
# Seed the RNG so the random test tensors are the same on every run.
torch.manual_seed(3)
@torch.no_grad()
def speed_test(func, args, inputs, name='Unknown'):
    """Time repeated calls of ``func(*inputs)`` with CUDA events.

    ``args.warmup_num`` untimed calls are issued first; the following
    ``args.test_num`` calls are bracketed by a single start/stop event pair.
    Runs with autograd disabled.

    Args:
        func: callable invoked as ``func(*inputs)``.
        args: object exposing ``warmup_num`` and ``test_num``.
        inputs: positional arguments for ``func``.
        name: label printed in the report line.

    Returns:
        Average time per timed call, as reported by ``Event.elapsed_time``
        (milliseconds).
    """
    start_evt = Event(enable_timing=True)
    stop_evt = Event(enable_timing=True)

    # Warm-up calls, excluded from the measurement.
    for _ in range(args.warmup_num):
        func(*inputs)

    start_evt.record()
    for _ in range(args.test_num):
        func(*inputs)
    stop_evt.record()
    # Wait until the stop event has actually been reached on the device.
    torch.cuda.synchronize()

    total_ms = start_evt.elapsed_time(stop_evt)
    avg_time = total_ms / args.test_num
    print(
        f'>>> {name: <10} finished {args.test_num} running, avg_time: {avg_time:.6f} ms')
    return avg_time
@torch.no_grad()
def check_forward_equal_with_pytorch_half():
    """Compare the fp16 fused forward pass with the fp32 reference and time both.

    Random values, sampling locations and attention logits are generated on
    the GPU.  ``FlashDeformAttnFunction`` receives half-precision values and
    a single tensor holding the flattened locations followed by the raw
    (pre-softmax) weights; ``MSDeformAttnFunction`` receives the float32
    values, the locations and the softmax-normalised weights.  The maximum
    absolute/relative error and the ``torch.allclose`` verdict are printed,
    followed by a one-row table with the average forward time of both
    functions (``time`` = fused, ``time_base`` = reference).

    Requires a CUDA device.  Returns nothing.
    """
    value = torch.rand(N, S, M, D).cuda() * 0.01
    # offset = (torch.rand(N, Lq, M, L, P, 2).cuda() * 2 - 1) / 10
    sampling_locations = torch.rand(N, Lq, M, L, P, 2).cuda()
    raw_weights = torch.rand(N, Lq, M, L, P).cuda() + 1e-5

    # Fused layout per head: all location coordinates, then all raw weights.
    flat_locations = sampling_locations.reshape(N, Lq, M, L * P * 2)
    flat_weights = raw_weights.reshape(N, Lq, M, L * P)
    sampling_loc_attn = torch.cat([flat_locations, flat_weights], dim=-1)

    # The reference takes weights normalised jointly over levels and points.
    attention_weights = torch.nn.functional.softmax(
        raw_weights.flatten(-2, -1), dim=-1
    ).unflatten(-1, (L, P))
    im2col_step = 128

    # NOTE(review): the two trailing values (P, 16) are extra arguments of the
    # fused function; their meaning is not visible here -- confirm.
    flash_fn_args = (
        value.half(),
        shapes,
        level_start_index,
        sampling_loc_attn.half(),
        im2col_step,
        P, 16
    )
    fused_out = FlashDeformAttnFunction.apply(*flash_fn_args)
    output_cuda = fused_out.detach().cpu().double()

    fn_args = (
        value,
        shapes,
        level_start_index,
        sampling_locations,
        attention_weights,
        im2col_step,
    )
    reference_out = MSDeformAttnFunction.apply(*fn_args)
    output_pytorch = reference_out.detach().double().cpu()

    abs_diff = (output_cuda - output_pytorch).abs()
    max_abs_err = abs_diff.max()
    max_rel_err = (abs_diff / output_pytorch.abs()).max()
    fwdok = torch.allclose(output_cuda, output_pytorch, rtol=1e-2, atol=1e-3)
    print(
        f"* {fwdok} check_forward_equal_with_pytorch_float: max_abs_err {max_abs_err:.2e} max_rel_err {max_rel_err:.2e}"
    )

    # Timing: reference first, then the fused implementation.
    test_args = edict({'warmup_num': 1000, 'test_num': 1000})
    exp_time_base = speed_test(
        MSDeformAttnFunction.apply, test_args, fn_args, name='exp')
    exp_time = speed_test(
        FlashDeformAttnFunction.apply, test_args, flash_fn_args, name='exp')

    timings = {'time': exp_time, 'time_base': exp_time_base}
    outputs = pd.DataFrame([timings], columns=list(timings))
    with pd.option_context(
        'display.max_rows', None, 'display.max_columns', None,
        'display.max_colwidth', None, 'display.width', None,
        'display.precision', 4, ):
        print(outputs)
# Script entry point: call check_forward_equal_with_pytorch_half() only when
# this file is executed directly, not when it is imported.
if __name__ == "__main__":
    check_forward_equal_with_pytorch_half()
# ------------------------------------------------------------------------------------------------
# Deformable DETR
# Copyright (c) 2020 SenseTime. All Rights Reserved.
# Licensed under the Apache License, Version 2.0 [see LICENSE for details]
# ------------------------------------------------------------------------------------------------
# Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
# ------------------------------------------------------------------------------------------------
from __future__ import absolute_import
from __future__ import print_function
from __future__ import division
from easydict import EasyDict as edict
from torch.cuda import Event
import pandas as pd
import time
import torch
import torch.nn as nn
from torch.autograd import gradcheck
from functions import MSDeformAttnFunction, ms_deform_attn_core_pytorch, FlashDeformAttnFunction
# First configuration (single level).  Every name assigned here is assigned
# again further down, so only the second configuration is in effect.
H = W = 256
N, M, D = 1, 8, 16
Lq, L, P = H * W, 1, 8
# x = x.reshape([B, H*W, G, D + self.num_split * 3])
shapes = torch.as_tensor([(H, W)], dtype=torch.long).cuda()
# shapes = torch.as_tensor([(H, W), (H // 2, W // 2)], dtype=torch.long).cuda()
# shapes = torch.as_tensor([(H, W), (H // 2, W // 2), (H // 4, W // 4), (H // 8, W // 8)], dtype=torch.long).cuda()

# Second configuration (four levels) -- the one the checks actually use.
H = W = 256
N, M, D = 1, 8, 32
Lq, L, P = 100 * 152, 4, 8
shapes = torch.Tensor(
    [[100, 152], [50, 76], [25, 38], [13, 19]]
).long().cuda()

# Offset of every level inside the flattened value tensor: exclusive running
# sum of the per-level areas, starting at zero.
level_start_index = torch.cat(
    (shapes.new_zeros((1,)), shapes.prod(1).cumsum(0)[:-1])
)
# Total number of spatial positions over all levels.
S = sum((lvl_h * lvl_w).item() for lvl_h, lvl_w in shapes)
def get_reference_points(spatial_shapes, device):
    """Build normalised pixel-centre reference points for all feature levels.

    Args:
        spatial_shapes: iterable of ``(H_, W_)`` pairs, one per level.
        device: device on which the coordinates are created.

    Returns:
        float32 tensor of shape ``(1, sum(H_ * W_), 2)``.  Each entry is an
        ``(x, y)`` pair: the pixel centre divided by the level's width and
        height.  Positions are listed row by row, level after level.
    """

    def _level_points(H_, W_):
        # Pixel centres sit half a pixel inside each cell.
        centre_y = torch.linspace(0.5, H_ - 0.5, H_, dtype=torch.float32, device=device)
        centre_x = torch.linspace(0.5, W_ - 0.5, W_, dtype=torch.float32, device=device)
        ref_y, ref_x = torch.meshgrid(centre_y, centre_x)
        ref_y = ref_y.reshape(-1)[None] / (H_)
        ref_x = ref_x.reshape(-1)[None] / (W_)
        return torch.stack((ref_x, ref_y), -1)

    levels = [_level_points(H_, W_) for H_, W_ in spatial_shapes]
    # reference_points = reference_points[:, :, None] * valid_ratios[:, None]
    return torch.cat(levels, 1)
# Seed the RNG so the random test tensors are the same on every run.
torch.manual_seed(3)
@torch.no_grad()
def speed_test(func, args, inputs, name='Unknown'):
    """Return the average time of one ``func(*inputs)`` call.

    The function is called ``args.warmup_num`` times without timing, then
    ``args.test_num`` times between two recorded CUDA events.  Autograd is
    disabled for the whole measurement.

    Args:
        func: callable invoked as ``func(*inputs)``.
        args: object exposing ``warmup_num`` and ``test_num``.
        inputs: positional arguments for ``func``.
        name: label printed in the report line.

    Returns:
        Elapsed time between the two events (``Event.elapsed_time``,
        milliseconds) divided by ``args.test_num``.
    """
    begin, end = Event(enable_timing=True), Event(enable_timing=True)

    remaining_warmup = args.warmup_num
    while remaining_warmup > 0:
        func(*inputs)
        remaining_warmup -= 1

    begin.record()
    remaining_timed = args.test_num
    while remaining_timed > 0:
        func(*inputs)
        remaining_timed -= 1
    end.record()

    # The events are recorded asynchronously; wait before reading them.
    torch.cuda.synchronize()
    avg_time = begin.elapsed_time(end) / args.test_num
    print(
        f'>>> {name: <10} finished {args.test_num} running, avg_time: {avg_time:.6f} ms')
    return avg_time
def check_forward_equal_with_pytorch_half(run_benchmark=False):
    """Check the fused deformable-attention gradients against the PyTorch reference.

    ``FlashDeformAttnFunction`` is run on float32 inputs, with sampling
    locations and raw (pre-softmax) attention weights packed into one
    tensor.  ``ms_deform_attn_core_pytorch`` is run on the same values with
    the weights passed through an explicit softmax.  Both results are
    back-propagated from ``sum() / 10`` and the following are printed:

    * forward ``allclose`` verdict and max absolute / relative error,
    * ``allclose`` verdict for the value gradient,
    * ``allclose`` verdict and max relative error for the location gradient,
    * ``allclose`` verdict and max relative error for the attention gradient.

    Args:
        run_benchmark: when ``False`` (default) the function stops after the
            gradient checks, as it always did.  When ``True`` it additionally
            prints the formatted forward summary and times the forward pass
            of ``FlashDeformAttnFunction`` and ``MSDeformAttnFunction`` with
            ``speed_test``.

    Requires a CUDA device.  Returns nothing.
    """
    value = torch.rand(N, S, M, D).cuda() * 0.01
    # Unused draw, kept only so the random stream (and therefore every tensor
    # generated below) is the same as before for the fixed seed.
    torch.rand(N, Lq, M, L, P, 2)
    sampling_locations = torch.rand(N, Lq, M, L, P, 2).cuda()
    attention_weights_origin = torch.rand(N, Lq, M, L, P).cuda() + 1e-5
    attention_weights_origin.requires_grad = True

    # Fused layout per head: all location coordinates, then all raw weights.
    sampling_loc_attn = torch.cat(
        [
            sampling_locations.detach().reshape(N, Lq, M, L * P * 2),
            attention_weights_origin.detach().reshape(N, Lq, M, L * P),
        ],
        dim=-1,
    )
    # The reference consumes weights normalised jointly over levels and
    # points; gradients flow back to ``attention_weights_origin``.
    attention_weights = torch.nn.functional.softmax(
        attention_weights_origin.flatten(-2, -1), dim=-1
    ).unflatten(-1, (L, P))
    im2col_step = 128

    value.requires_grad = True
    sampling_loc_attn.requires_grad = True
    # ``.float()`` on float32 tensors returns the same tensor, so gradients
    # accumulate directly in ``value.grad`` / ``sampling_loc_attn.grad``.
    output_cuda = FlashDeformAttnFunction.apply(
        value.float(),
        shapes,
        level_start_index,
        sampling_loc_attn.float(),
        im2col_step,
    )
    (output_cuda.float().sum() / 10).backward()

    # Separate leaf for the reference so the value gradients can be compared.
    value1 = value.detach()
    value1.requires_grad = True
    sampling_locations.requires_grad = True
    output_pytorch = ms_deform_attn_core_pytorch(
        value1, shapes, sampling_locations, attention_weights
    )
    (output_pytorch.sum() / 10).backward()

    # --- forward ---
    max_abs_err = (output_cuda.float() - output_pytorch).abs().max()
    max_rel_err = ((output_cuda.float() - output_pytorch).abs() / output_pytorch.abs()).max()
    fwdok = torch.allclose(output_cuda.float(), output_pytorch, rtol=1e-2, atol=1e-3)
    print(fwdok)
    print(max_abs_err, max_rel_err)

    # --- gradient w.r.t. value ---
    bwdok1 = torch.allclose(value.grad, value1.grad, rtol=1e-2, atol=1e-3)
    print(bwdok1)

    # --- gradient w.r.t. sampling locations (leading L*P*2 fused entries) ---
    locgrad1 = sampling_locations.grad
    locgrad2 = sampling_loc_attn.grad[..., :L * P * 2].reshape(*sampling_locations.shape)
    bwdok2 = torch.allclose(locgrad1, locgrad2, rtol=1e-2, atol=1e-3)
    print(bwdok2)
    rel_err = (locgrad1 - locgrad2).abs() / (locgrad1.abs() + 1e-3)
    print(rel_err.max())

    # --- gradient w.r.t. attention weights (trailing L*P fused entries) ---
    attngrad1 = attention_weights_origin.grad
    attngrad2 = sampling_loc_attn.grad[..., L * P * 2:].reshape(*attention_weights_origin.shape)
    # Compare the attention gradients; this check used to re-test the
    # location gradients by mistake.
    bwdok3 = torch.allclose(attngrad1, attngrad2, rtol=1e-2, atol=1e-3)
    print(bwdok3)
    rel_err = (attngrad1 - attngrad2).abs() / (attngrad1.abs() + 1e-3)
    print(rel_err.max())

    # The script historically stopped here via exit(); returning keeps that
    # default without killing an importing interpreter.
    if not run_benchmark:
        return

    print(
        f"* {fwdok} check_forward_equal_with_pytorch_float: max_abs_err {max_abs_err:.2e} max_rel_err {max_rel_err:.2e}"
    )

    fn_args = (
        value,
        shapes,
        level_start_index,
        sampling_locations,
        attention_weights,
        im2col_step,
    )
    flash_dcn_fn_args = (
        value.half(),
        shapes,
        level_start_index,
        sampling_loc_attn.half(),
        im2col_step,
    )
    test_args = edict({'warmup_num': 50, 'test_num': 100})
    # The class imported by this file is FlashDeformAttnFunction; the
    # previously referenced FlashMSDeformAttnFunction is not defined.
    exp_time = speed_test(
        FlashDeformAttnFunction.apply, test_args, flash_dcn_fn_args, name='exp')
    exp_time_base = speed_test(
        MSDeformAttnFunction.apply, test_args, fn_args, name='exp')

    results = [{}]
    results[0]['time'] = exp_time
    results[0]['time_base'] = exp_time_base
    columns = list(results[0].keys())
    outputs = pd.DataFrame(results, columns=columns)
    with pd.option_context(
        'display.max_rows', None, 'display.max_columns', None,
        'display.max_colwidth', None, 'display.width', None,
        'display.precision', 4, ):
        print(outputs)
# Script entry point: call check_forward_equal_with_pytorch_half() only when
# this file is executed directly, not when it is imported.
if __name__ == "__main__":
    check_forward_equal_with_pytorch_half()
\ No newline at end of file
# ------------------------------------------------------------------------------------------------
# Deformable Convolution v4
# Copyright (c) 2024 OpenGVLab
# Licensed under the Apache License, Version 2.0 [see LICENSE for details]
# ------------------------------------------------------------------------------------------------
# Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
# ------------------------------------------------------------------------------------------------
import os
import glob
import torch
# Packaging-related imports
from setuptools import find_packages, setup
from torch.utils.cpp_extension import CUDA_HOME, CppExtension, CUDAExtension
# Builds the extension description (kept as in the original project; used
# when the package is not being assembled from a prebuilt binary).
def get_extensions():
    """Describe the ``DCNv4.ext`` CUDA extension for ``setup()``.

    Sources are every ``*.cpp`` in ``src/`` and ``src/cpu/`` plus every
    ``*.cu`` in ``src/cuda/``, located relative to this file.  ``src/`` is
    also the include directory.

    Returns:
        A one-element list holding the ``CUDAExtension``.

    Raises:
        NotImplementedError: if ``torch.cuda.is_available()`` is false or
            ``CUDA_HOME`` is ``None``; a CPU-only build is not supported.
    """
    if not torch.cuda.is_available() or CUDA_HOME is None:
        raise NotImplementedError('Cuda is not available')

    package_root = os.path.dirname(os.path.abspath(__file__))
    extensions_dir = os.path.join(package_root, "src")

    # Order matters only for reproducibility: top-level, CPU, then CUDA files.
    patterns = (
        os.path.join(extensions_dir, "*.cpp"),
        os.path.join(extensions_dir, "cpu", "*.cpp"),
        os.path.join(extensions_dir, "cuda", "*.cu"),
    )
    discovered = []
    for pattern in patterns:
        discovered.extend(glob.glob(pattern))
    sources = [os.path.join(extensions_dir, path) for path in discovered]

    nvcc_flags = [
        "-DCUDA_HAS_FP16=1",
        "-D__CUDA_NO_HALF_OPERATORS__",
        "-D__CUDA_NO_HALF_CONVERSIONS__",
        "-D__CUDA_NO_HALF2_OPERATORS__",
        "-O3",
    ]
    cuda_extension = CUDAExtension(
        "DCNv4.ext",  # keep the original module name so it is easy to swap later
        sources,
        include_dirs=[extensions_dir],
        define_macros=[("WITH_CUDA", None)],
        extra_compile_args={"cxx": [], "nvcc": nvcc_flags},
    )
    return [cuda_extension]
# --- Core modification logic ---
# Decide whether we are in wheel-building mode.
# In wheel-building mode nothing is compiled; the existing .so is shipped as package data instead.
# Note: setuptools handles "package an extension module" and "package a data file" in conflicting ways, so ext_modules must be disabled while building the wheel.
if __name__ == "__main__":
    # Read the environment variable that decides whether compilation is skipped.
    # A hard-coded boolean, or a check for the existence of some file, would work as well.
    build_so = int(os.getenv("DCNv4_BUILD_SO", "0"))
    # Assemble the setup() arguments.
    kwargs = {
        "name": "DCNv4",
        "version": "1.0.0.post2",
        "author": "Yuwen Xiong, Feng Wang",
        "url": "",
        "description": "PyTorch Wrapper for CUDA Functions of DCNv4",
        "packages": ['DCNv4', 'DCNv4/functions', 'DCNv4/modules'],
        # NOTE(review): a freshly compiled extension normally carries an ABI
        # tag in its file name (see the commented-out entry); confirm that a
        # plain ext.so really exists in DCNv4/ before building the wheel.
        "package_data": {
            "DCNv4": ["ext.so"], # assumes ext.so has been generated inside the DCNv4 directory
            # "DCNv4": ["ext.cpython-310-x86_64-linux-gnu.so"], # assumes ext.so has been generated inside the DCNv4 directory
        },
        "cmdclass": {"build_ext": torch.utils.cpp_extension.BuildExtension},
        # Make sure a correct .dist-info is produced.
        "zip_safe": False,
        # The options below keep .egg-info out of the current directory.
        "options": {
            'egg_info': {
                'egg_base': '/tmp' # write egg-info into a temporary directory
            }
        },
    }
    if build_so:
        # Normal development mode: compile the extension.
        kwargs["ext_modules"] = get_extensions()
    else:
        print("=== BUILD WHEEL MODE: Skipping compilation, using existing ext.so ===")
        # When building the wheel, do not pass any ext_modules.
        # The .so file is pulled in through MANIFEST.in or package_data.
        # bdist_wheel ignores .so files by default, so the .so must already sit inside the package directory.
        # No ext_modules are passed here; an external script or MANIFEST.in does the rest.
        # Simplest approach: leave ext_modules out of setup and make sure the .so is already in DCNv4/.
        kwargs["ext_modules"] = [] # force "no compilation"
    setup(**kwargs)
\ No newline at end of file
#ifndef FMSDACOMMON
#define FMSDACOMMON
#include <algorithm>
#include <cstdio>
#include <cstring>
#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h>
#include <THC/THCAtomics.cuh>
#include <ATen/ATen.h>
#include <ATen/OpMathType.h>
#include <ATen/cuda/CUDAContext.h>
#include <cooperative_groups.h>
#include <cooperative_groups/memcpy_async.h>
#include <cuda.h>
#include <cuda_fp16.h>
#include <cuda_runtime.h>
#ifdef _WIN32
#define uint unsigned int
#endif
constexpr int kWarpSize = 32;
#define opmath_t at::opmath_type<scalar_t>
// Returns how many blocks of `num_threads` threads are needed to cover N
// work items, i.e. N / num_threads rounded up.
inline int GET_BLOCKS(const int N, const int num_threads) {
  const int rounded_up = N + num_threads - 1;
  return rounded_up / num_threads;
}
#define CUDA_KERNEL_LOOP(i, n) \
for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < (n); \
i += blockDim.x * gridDim.x)
// Returns true when the D / d_stride threads that cooperate on one group fit
// inside a single warp and divide the warp size evenly.
//
// d_stride: number of channels handled by one thread.
// D:        number of channels per group.
//
// Degenerate inputs (d_stride <= 0, or D < d_stride so that no thread is
// assigned) return false. Previously those cases divided or took a
// remainder by zero.
// NOTE(review): presumably used to decide whether a warp-primitive backward
// kernel may be launched -- confirm at the call site.
inline bool check_backward_warpp(int d_stride, int D){
  if (d_stride <= 0) {
    return false;
  }
  const int n_group_threads = D / d_stride;
  if (n_group_threads <= 0 || n_group_threads > kWarpSize) {
    return false;
  }
  return kWarpSize % n_group_threads == 0;
}
// Forward sampling step: bilinearly interpolates `p_value` at the fractional
// position (h_px, w_px) for c_per_thread consecutive channels and adds the
// result, scaled by `attn`, to out_reg_array.
//
//   out_reg_array  accumulator with c_per_thread entries (updated in place)
//   p_value        base pointer of the value data being sampled
//   height, width  spatial extent used for the bounds checks
//   h_px, w_px     sampling position in pixel coordinates
//   attn           weight applied to the interpolated sample
//   w_stride       element stride between horizontally adjacent pixels
//   base_ptr       element offset of this thread's first channel in a pixel
//
// Corners that fail their bounds check keep their zero-initialised values
// and therefore contribute nothing.
// NOTE(review): only h_low/w_low >= 0 and h_high/w_high < size are tested
// here; the remaining bounds are presumably guaranteed by the caller's range
// check on (h_px, w_px) -- confirm.
template <typename scalar_t, typename transfer_t, int c_per_thread>
__device__ void ms_deform_attn_im2col_bilinear(
    opmath_t out_reg_array[], const scalar_t *&p_value, const int &height,
    const int &width, const opmath_t &h_px, const opmath_t &w_px,
    const opmath_t &attn, const int &w_stride, const int &base_ptr) {
  // Integer corners surrounding the sampling position.
  const int h_low = floor(h_px);
  const int w_low = floor(w_px);
  const int h_high = h_low + 1;
  const int w_high = w_low + 1;
  // Fractional distances to the low corner (lh, lw) and their complements.
  const opmath_t lh = h_px - h_low;
  const opmath_t lw = w_px - w_low;
  const opmath_t hh = 1 - lh;
  const opmath_t hw = 1 - lw;
  // Bilinear weights: w1 = (low, low), w2 = (low, high), w3 = (high, low),
  // w4 = (high, high), written as (row, column).
  const opmath_t w1 = hh * hw, w2 = hh * lw, w3 = lh * hw, w4 = lh * lw;
  // Element offsets of the four corners relative to p_value.
  const int h_stride = width * w_stride;
  const int h_low_ptr_offset = h_low * h_stride;
  const int h_high_ptr_offset = h_low_ptr_offset + h_stride;
  const int w_low_ptr_offset = w_low * w_stride;
  const int w_high_ptr_offset = w_low_ptr_offset + w_stride;
  int idx1 = h_low_ptr_offset + w_low_ptr_offset + base_ptr;
  int idx2 = h_low_ptr_offset + w_high_ptr_offset + base_ptr;
  int idx3 = h_high_ptr_offset + w_low_ptr_offset + base_ptr;
  int idx4 = h_high_ptr_offset + w_high_ptr_offset + base_ptr;
  // Per-corner channel values; they stay zero for out-of-bounds corners.
  scalar_t v1_array[c_per_thread] = {0.};
  scalar_t v2_array[c_per_thread] = {0.};
  scalar_t v3_array[c_per_thread] = {0.};
  scalar_t v4_array[c_per_thread] = {0.};
  // Each corner is fetched with one transfer_t-wide copy.
  // NOTE(review): assumes sizeof(transfer_t) == c_per_thread *
  // sizeof(scalar_t) and suitably aligned addresses -- confirm at the
  // instantiation sites.
  if (h_low >= 0 && w_low >= 0) {
    auto p1 = p_value + idx1;
    *(transfer_t *)(v1_array) = *(transfer_t *)(p1);
  }
  if (h_low >= 0 && w_high < width) {
    auto p2 = p_value + idx2;
    *(transfer_t *)(v2_array) = *(transfer_t *)(p2);
  }
  if (h_high < height && w_low >= 0) {
    auto p3 = p_value + idx3;
    *(transfer_t *)(v3_array) = *(transfer_t *)(p3);
  }
  if (h_high < height && w_high < width) {
    auto p4 = p_value + idx4;
    *(transfer_t *)(v4_array) = *(transfer_t *)(p4);
  }
  // Blend the four corners in opmath_t precision and accumulate.
#pragma unroll
  for (int i = 0; i < c_per_thread; i++) {
    out_reg_array[i] +=
        (opmath_t)attn *
        (w1 * (opmath_t)v1_array[i] + w2 * (opmath_t)v2_array[i] +
         w3 * (opmath_t)v3_array[i] + w4 * (opmath_t)v4_array[i]);
  }
}
// Backward sampling step for one sampling point and c_per_thread consecutive
// channels.
//
// Using the same bilinear decomposition as the forward pass, this function
//   * atomically adds top_grad * attn * corner_weight into grad_im at every
//     in-bounds corner (gradient w.r.t. the sampled values), and
//   * writes three partial results for this thread's channels to grad_offset:
//       grad_offset[0]  gradient w.r.t. the horizontal offset
//       grad_offset[1]  gradient w.r.t. the vertical offset
//       grad_offset[2]  sum over channels of (interpolated value * top_grad),
//                       i.e. the gradient w.r.t. the weight `attn`
//
//   p_value         base pointer of the value data that was sampled
//   height, width   spatial extent used for the bounds checks
//   h_px, w_px      sampling position in pixel coordinates
//   attn            weight that was applied to this sample
//   w_stride        element stride between horizontally adjacent pixels
//   base_ptr        element offset of this thread's first channel in a pixel
//   offset_scale_h/w  factors multiplied into the two offset gradients
//   top_grad        pointer to this thread's c_per_thread output gradients
//   grad_im         gradient buffer indexed like p_value (atomic updates)
//   grad_offset     destination for the three values listed above
template <typename scalar_t, typename transfer_t, int c_per_thread>
__device__ void ms_deform_attn_col2im_bilinear(
    const scalar_t *&p_value, const int &height, const int &width,
    const opmath_t &h_px, const opmath_t &w_px, const opmath_t &attn,
    const int &w_stride, const int &base_ptr, const opmath_t offset_scale_h,
    const opmath_t offset_scale_w, const scalar_t *&top_grad,
    opmath_t *&grad_im, opmath_t *grad_offset) {
  // Integer corners surrounding the sampling position.
  const int h_low = floor(h_px);
  const int w_low = floor(w_px);
  const int h_high = h_low + 1;
  const int w_high = w_low + 1;
  // Fractional distances to the low corner (lh, lw) and their complements.
  const opmath_t lh = h_px - h_low;
  const opmath_t lw = w_px - w_low;
  const opmath_t hh = 1 - lh;
  const opmath_t hw = 1 - lw;
  // Bilinear weights: w1 = (low, low), w2 = (low, high), w3 = (high, low),
  // w4 = (high, high), written as (row, column).
  const opmath_t w1 = hh * hw, w2 = hh * lw, w3 = lh * hw, w4 = lh * lw;
  // Fetch the output gradients with one transfer_t-wide copy, then widen
  // them to opmath_t.
  // NOTE(review): assumes sizeof(transfer_t) == c_per_thread *
  // sizeof(scalar_t) -- confirm at the instantiation sites.
  scalar_t _top_grad_array[c_per_thread] = {0.};
  *(transfer_t *)(_top_grad_array) = *(transfer_t *)(top_grad);
  opmath_t top_grad_array[c_per_thread] = {0.};
  for (int i = 0; i < c_per_thread; ++i) {
    top_grad_array[i] = (opmath_t)(_top_grad_array[i]);
  }
  // Element offsets of the four corners relative to p_value / grad_im.
  const int h_stride = width * w_stride;
  const int h_low_ptr_offset = h_low * h_stride;
  const int h_high_ptr_offset = h_low_ptr_offset + h_stride;
  const int w_low_ptr_offset = w_low * w_stride;
  const int w_high_ptr_offset = w_low_ptr_offset + w_stride;
  int idx1 = h_low_ptr_offset + w_low_ptr_offset + base_ptr;
  int idx2 = h_low_ptr_offset + w_high_ptr_offset + base_ptr;
  int idx3 = h_high_ptr_offset + w_low_ptr_offset + base_ptr;
  int idx4 = h_high_ptr_offset + w_high_ptr_offset + base_ptr;
  // Per-corner channel values; they stay zero for out-of-bounds corners.
  scalar_t v1_array[c_per_thread] = {0.};
  scalar_t v2_array[c_per_thread] = {0.};
  scalar_t v3_array[c_per_thread] = {0.};
  scalar_t v4_array[c_per_thread] = {0.};
  // Per-channel derivative of the interpolated value with respect to the
  // vertical (h) and horizontal (w) sampling coordinate.
  opmath_t grad_h_weight[c_per_thread] = {0.};
  opmath_t grad_w_weight[c_per_thread] = {0.};
  // Corner (h_low, w_low): weight hh * hw.
  if (h_low >= 0 && w_low >= 0) {
    auto p1 = p_value + idx1;
    *(transfer_t *)(v1_array) = *(transfer_t *)(p1);
#pragma unroll
    for (int i = 0; i < c_per_thread; ++i) {
      grad_h_weight[i] -= hw * v1_array[i];
      grad_w_weight[i] -= hh * v1_array[i];
      atomicAdd(grad_im + idx1 + i, top_grad_array[i] * attn * w1);
    }
  }
  // Corner (h_low, w_high): weight hh * lw.
  if (h_low >= 0 && w_high < width) {
    auto p2 = p_value + idx2;
    *(transfer_t *)(v2_array) = *(transfer_t *)(p2);
#pragma unroll
    for (int i = 0; i < c_per_thread; ++i) {
      grad_h_weight[i] -= lw * v2_array[i];
      grad_w_weight[i] += hh * v2_array[i];
      atomicAdd(grad_im + idx2 + i, top_grad_array[i] * attn * w2);
    }
  }
  // Corner (h_high, w_low): weight lh * hw.
  if (h_high < height && w_low >= 0) {
    auto p3 = p_value + idx3;
    *(transfer_t *)(v3_array) = *(transfer_t *)(p3);
#pragma unroll
    for (int i = 0; i < c_per_thread; ++i) {
      grad_h_weight[i] += hw * v3_array[i];
      grad_w_weight[i] -= lh * v3_array[i];
      atomicAdd(grad_im + idx3 + i, top_grad_array[i] * attn * w3);
    }
  }
  // Corner (h_high, w_high): weight lh * lw.
  if (h_high < height && w_high < width) {
    auto p4 = p_value + idx4;
    *(transfer_t *)(v4_array) = *(transfer_t *)(p4);
#pragma unroll
    for (int i = 0; i < c_per_thread; ++i) {
      grad_h_weight[i] += lw * v4_array[i];
      grad_w_weight[i] += lh * v4_array[i];
      atomicAdd(grad_im + idx4 + i, top_grad_array[i] * attn * w4);
    }
  }
  // Offset gradients: reduce over this thread's channels first, then apply
  // the factors that are shared by all channels.
  opmath_t _grad_offset_x = 0;
  opmath_t _grad_offset_y = 0;
#pragma unroll
  for (int i = 0; i < c_per_thread; ++i) {
    _grad_offset_x +=
        grad_w_weight[i] * top_grad_array[i]; // channel aware term
    _grad_offset_y +=
        grad_h_weight[i] * top_grad_array[i]; // channel aware term
  }
  _grad_offset_x *= (offset_scale_w * attn); // channel shared term
  _grad_offset_y *= (offset_scale_h * attn); // channel shared term
  *grad_offset = _grad_offset_x;
  *(grad_offset + 1) = _grad_offset_y;
  // Weight gradient: interpolated value times output gradient, summed over
  // this thread's channels (not scaled by attn).
  opmath_t current_val;
  opmath_t _grad_offset_z = 0;
#pragma unroll
  for (int i = 0; i < c_per_thread; i++) {
    current_val = (opmath_t)(w1 * v1_array[i] + w2 * v2_array[i] +
                             w3 * v3_array[i] + w4 * v4_array[i]);
    _grad_offset_z += current_val * top_grad_array[i];
  }
  *(grad_offset + 2) = _grad_offset_z;
}
#endif
#include <algorithm>
#include <cstdio>
#include <cstring>
#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h>
#include <THC/THCAtomics.cuh>
#include <ATen/ATen.h>
#include <ATen/OpMathType.h>
#include <ATen/cuda/CUDAContext.h>
#include <cooperative_groups.h>
#include <cooperative_groups/memcpy_async.h>
#include <cuda.h>
#include <cuda_fp16.h>
#include <cuda_runtime.h>
#include "common.h"
// Backward kernel of the deformable sampling operator (shared-memory
// reduction variant).
//
// Thread layout, as read from the index computations below:
//   threadIdx.x  channel chunk inside a group (d_stride channels per thread)
//   threadIdx.y  group index (gi)
//   threadIdx.z  which of the block's `block_multiplier` output positions
//   blockIdx.x   selects the first output position handled by the block
//
// For every kernel tap the kernel calls ms_deform_attn_col2im_bilinear,
// which scatters the value gradient into grad_im with atomics and leaves
// per-thread partial offset/mask gradients in shared memory. Thread x == 0
// of each (z, y) pair then sums those partials over x and writes the result
// to grad_offset. With `softmax` enabled the mask is normalised with a
// softmax on load and the mask gradient is propagated back through it.
//
// Per-position layout of p_offset / grad_offset (row length
// padded_offset_dim): for each group, 2*K offset values followed by K mask
// values.
//
// Dynamic shared memory, in order:
//   cache_g_mask_before_softmax  block_multiplier * G * K
//   cache_grad_offset            block_multiplier * G * blockDim.x * 3
//   mask values                  K entries per (z, group) pair
// NOTE(review): the template parameter L is not used in this kernel.
template <typename scalar_t, int d_stride, typename transfer_t, int L, int K,
          bool softmax>
__global__ void backward_kernel_dcn(
    const scalar_t *p_value, const scalar_t *p_offset,
    const scalar_t *grad_output, const int G, const int D, const int Q,
    const int kernel_h, const int kernel_w, const int stride_h,
    const int stride_w, const int pad_h, const int pad_w, const int dilation_h,
    const int dilation_w, const int height_in, const int width_in,
    const int height_out, const int width_out, const opmath_t offset_scale,
    const int remove_center, const int block_multiplier, opmath_t *grad_im,
    opmath_t *grad_offset, const int padded_offset_dim) {
  extern __shared__ char _s[];
  // qi: output position inside the batch element, bi: batch element,
  // di_s: first channel handled by this thread, gi: group.
  const int &qi = (blockIdx.x * block_multiplier % Q) + threadIdx.z;
  const int &bi = blockIdx.x * block_multiplier / Q;
  const int &di_s = threadIdx.x * d_stride;
  const int &gi = threadIdx.y;
  constexpr int li = 0;
  // Carve the three shared-memory regions out of the dynamic buffer.
  opmath_t *const cache_g_mask_before_softmax = (opmath_t *)(_s); // mG x K
  opmath_t *const cache_grad_offset =
      (opmath_t *)(cache_g_mask_before_softmax +
                   block_multiplier * G * K); // mG x blockDim.x x 3
  // p_mask_shm already points at the K-entry slice of this (z, group) pair.
  opmath_t *const p_mask_shm =
      (opmath_t *)(cache_grad_offset + block_multiplier * G * blockDim.x * 3) +
      (threadIdx.z * G + gi) * K;
  // Start of this position's and group's [2K offsets | K masks] record.
  const scalar_t *p_offset_ptr = p_offset + (bi*Q + qi)*padded_offset_dim + gi*K*3;
  // The D / d_stride threads of a group cooperatively copy the K mask
  // values: num_iter full rounds plus a partial round of `remainder`.
  const int mask_length = K;
  const int num_thread = (D / d_stride);
  const int num_iter = mask_length / num_thread;
  const int remainder = mask_length - num_iter * num_thread;
  // Output gradient for this thread's channel chunk.
  const scalar_t *top_grad = grad_output + ((bi * Q + qi) * G + gi) * D + di_s;
  __syncthreads();
  for (int i = 0; i < num_iter; i++) {
    *(p_mask_shm + num_thread * i + threadIdx.x) =
        *(scalar_t *)(p_offset_ptr + K * 2 + num_thread * i + threadIdx.x);
  }
  if (remainder > 0 && threadIdx.x < remainder) {
    *(p_mask_shm + num_thread * num_iter + threadIdx.x) = *(
        scalar_t *)(p_offset_ptr + K * 2 + num_thread * num_iter + threadIdx.x);
  }
  // NOTE(review): when softmax == false there is no barrier between the mask
  // load above and the first read of p_mask_shm[mask_idx] inside the tap
  // loop, although that entry may have been written by another thread of
  // the group. Confirm that this is safe for every launch configuration.
  if (softmax) {
    __syncthreads();
    // transfer offset from global memory to shared memory >
    // Calculate softmax over L and K
    if (threadIdx.x == 0) { // gi != 0, di = 0, li = 0
      opmath_t softmax_max = -1e100;
      opmath_t softmax_sum = 0.0;
      // get max
      for (int j = 0; j < K; j++) {
        softmax_max = max(softmax_max, p_mask_shm[j]);
      }
      // get sumexp
      for (int j = 0; j < K; j++) {
        opmath_t exp_results = exp(p_mask_shm[j] - softmax_max);
        p_mask_shm[j] = exp_results;
        softmax_sum += exp_results;
      }
      // normalize
      for (int j = 0; j < K; j++) {
        p_mask_shm[j] /= softmax_sum;
      }
    }
    __syncthreads();
  }
  int offset_idx = 0;
  int mask_idx = 0;
  // Addressing into the value tensor: w_stride elements per pixel, base_ptr
  // selects this thread's channel chunk inside a pixel.
  const int w_stride = G * D;
  const int base_ptr = gi * D + di_s;
  const scalar_t *p_value_ptr =
      p_value + (bi * (height_in * width_in)) * (G * D);
  opmath_t *grad_im_ptr = grad_im + (bi * (height_in * width_in)) * (G * D);
  // Input-space position of the kernel centre for this output position ...
  const int p0_w = ((dilation_w * (kernel_w - 1)) >> 1) - pad_w +
                   (qi % width_out) * stride_w;
  const int p0_h = ((dilation_h * (kernel_h - 1)) >> 1) - pad_h +
                   (qi / width_out) * stride_h;
  // ... shifted by the scaled half extent of the kernel, so that adding
  // (tap * dilation + learned offset) * offset_scale gives the sample point.
  const opmath_t p0_w_ =
      p0_w - ((dilation_w * (kernel_w - 1)) >> 1) * offset_scale;
  const opmath_t p0_h_ =
      p0_h - ((dilation_h * (kernel_h - 1)) >> 1) * offset_scale;
  const int center_h = kernel_h / 2;
  const int center_w = kernel_w / 2;
  // Output cursor for the offset gradients; advanced by 2 per processed tap.
  grad_offset += (bi*Q + qi)*padded_offset_dim + gi*K*3;
  // The mask gradients start after the 2*K offset gradients of the group.
  opmath_t *grad_offset_softmax = grad_offset + K * 2;
  // This thread's 3-entry scratch slot in cache_grad_offset.
  int cache_grad_off_idx =
      ((threadIdx.z * G + threadIdx.y) * blockDim.x + threadIdx.x) * 3;
  for (int i = 0; i < kernel_w; ++i) {
    for (int j = 0; j < kernel_h; ++j) {
      // Skip the centre tap when remove_center is set.
      if (i != center_w || j != center_h || !remove_center) {
        const opmath_t w_im =
            p0_w_ + (i * dilation_w + (opmath_t)p_offset_ptr[offset_idx]) *
                        offset_scale;
        const opmath_t h_im =
            p0_h_ + (j * dilation_h + (opmath_t)p_offset_ptr[offset_idx + 1]) *
                        offset_scale;
        const opmath_t attn = p_mask_shm[mask_idx];
        // Reset the scratch slot; it stays zero when the sample is skipped.
        cache_grad_offset[cache_grad_off_idx] = 0;
        cache_grad_offset[cache_grad_off_idx + 1] = 0;
        cache_grad_offset[cache_grad_off_idx + 2] = 0;
        // Only samples that touch at least one valid pixel contribute.
        if (h_im > -1 && w_im > -1 && h_im < height_in && w_im < width_in) {
          ms_deform_attn_col2im_bilinear<scalar_t, transfer_t, d_stride>(
              p_value_ptr, height_in, width_in, h_im, w_im, attn, w_stride,
              base_ptr, offset_scale, offset_scale, top_grad, grad_im_ptr,
              cache_grad_offset + cache_grad_off_idx);
        }
        // aggregated across different channel for offset
        __syncthreads();
        if (threadIdx.x == 0) { //
          // Serial sum of the group's blockDim.x partial results.
          int _didx = (threadIdx.z * G + threadIdx.y) * blockDim.x * 3;
          opmath_t _grad_w = cache_grad_offset[_didx],
                   _grad_h = cache_grad_offset[_didx + 1],
                   _grad_a = cache_grad_offset[_didx + 2];
          for (int c_id = 1; c_id < blockDim.x; ++c_id) {
            _grad_w += cache_grad_offset[_didx + 3 * c_id];
            _grad_h += cache_grad_offset[_didx + 3 * c_id + 1];
            _grad_a += cache_grad_offset[_didx + 3 * c_id + 2];
          }
          *(grad_offset) = _grad_w; // B x H x W x G x L x K x 3
          *(grad_offset + 1) = _grad_h; // B x H x W x G x L x K x 3
          // With softmax the mask gradient is staged (times the normalised
          // weight) for the softmax backward pass below; otherwise it is
          // final and written out directly.
          if (softmax) {
            cache_g_mask_before_softmax[(threadIdx.z * G + threadIdx.y) * K +
                                        mask_idx] = _grad_a * attn;
          }
          else{
            grad_offset_softmax[mask_idx] = _grad_a;
          }
        }
        // Keep the scratch slots intact until the reduction has read them.
        __syncthreads();
        offset_idx += 2;
        mask_idx += 1;
        grad_offset += 2;
      }
    }
  }
  // backward for softmax
  if(softmax){
    if (threadIdx.x == 0) {
      const opmath_t* group_g_mask = cache_g_mask_before_softmax + (threadIdx.z*G + threadIdx.y)*K;
      // For every i: staged_i - softmax_i * sum_j staged_j, where staged
      // already contains the factor softmax_j.
#pragma unroll
      for (int i = 0; i < K; ++i) {
        opmath_t sum = 0.;
        for (int j = 0; j < K; ++j) {
          sum += group_g_mask[j]; // dL/di * di/dj
        }
        *(grad_offset_softmax) = group_g_mask[i] - p_mask_shm[i] * sum;
        grad_offset_softmax += 1;
      }
    }
    __syncthreads();
  }
}
// Backward DCNv4 kernel that reduces offset/mask gradients with warp shuffles.
//
// Thread layout: threadIdx.x indexes a slice of d_stride channels inside a
// group, threadIdx.y the group (G groups), threadIdx.z one of
// `block_multiplier` consecutive output locations handled by the block.
//
// For every sampling tap each thread gets its partial gradients w.r.t. the
// x offset, y offset and mask weight from ms_deform_attn_col2im_bilinear
// (which is also handed grad_im_ptr). The blockDim.x threads of a group sum
// their partials with __shfl_down_sync and thread x == 0 writes the result to
// grad_offset.
//
// Dynamic shared memory, in opmath_t elements:
//   [0, block_multiplier*G*K)                      mask weights
//   [block_multiplier*G*K, 2*block_multiplier*G*K) grad_mask * mask (softmax)
//
// NOTE(review): the shuffle reduction assumes blockDim.x is a power of two
// that divides kWarpSize so that a group never straddles a warp; the host
// launcher is expected to guarantee this through check_backward_warpp —
// confirm against its definition.
template <typename scalar_t, int d_stride, typename transfer_t, int L, int K,
          bool softmax>
__global__ void backward_kernel_dcn_warp_primitive(
    const scalar_t *p_value, const scalar_t *p_offset,
    const scalar_t *grad_output, const int G, const int D, const int Q,
    const int kernel_h, const int kernel_w, const int stride_h,
    const int stride_w, const int pad_h, const int pad_w, const int dilation_h,
    const int dilation_w, const int height_in, const int width_in,
    const int height_out, const int width_out, const opmath_t offset_scale,
    const int remove_center, const int block_multiplier, opmath_t *grad_im,
    opmath_t *grad_offset, const int padded_offset_dim) {
  extern __shared__ char _s[];

  const int qi = (blockIdx.x * block_multiplier % Q) + threadIdx.z;
  const int bi = blockIdx.x * block_multiplier / Q;
  const int di_s = threadIdx.x * d_stride;
  const int gi = threadIdx.y;

  // Lanes of the current group inside its warp. The mask is built with
  // unsigned arithmetic: a signed `1 << 32` (blockDim.x == 32) is undefined
  // and shifting a signed value into the sign bit overflows.
  const int group_per_warp = kWarpSize / blockDim.x;
  const int group_in_warp_id =
      (threadIdx.z * G + threadIdx.y) % group_per_warp;
  const unsigned group_bits =
      blockDim.x >= 32 ? 0xffffffffu : ((1u << blockDim.x) - 1u);
  const unsigned lane_mask = group_bits << (group_in_warp_id * blockDim.x);

  opmath_t *const p_mask_shm = (opmath_t *)(_s) + (threadIdx.z * G + gi) * K;
  // Only touched by threadIdx.x == 0 of each group.
  opmath_t *cache_g_mask_before_softmax =
      (opmath_t *)(_s) + block_multiplier * G * K + (threadIdx.z * G + gi) * K;

  const scalar_t *p_offset_ptr =
      p_offset + (bi * Q + qi) * padded_offset_dim + gi * K * 3;
  const int num_thread = (D / d_stride);
  const int num_iter = K / num_thread;
  const int remainder = K - num_iter * num_thread;

  const scalar_t *top_grad = grad_output + ((bi * Q + qi) * G + gi) * D + di_s;

  // Cooperative copy of the K mask logits (stored after the 2*K offsets) from
  // global to shared memory; each thread of the group copies a strided subset.
  for (int i = 0; i < num_iter; i++) {
    *(p_mask_shm + num_thread * i + threadIdx.x) =
        *(p_offset_ptr + K * 2 + num_thread * i + threadIdx.x);
  }
  if (remainder > 0 && threadIdx.x < remainder) {
    *(p_mask_shm + num_thread * num_iter + threadIdx.x) =
        *(p_offset_ptr + K * 2 + num_thread * num_iter + threadIdx.x);
  }
  // Every thread reads mask entries written by other threads below, so the
  // barrier is required whether or not softmax is applied.
  __syncthreads();

  if (softmax) {
    // Numerically stable softmax over the K taps, done by one thread per group.
    if (threadIdx.x == 0) {
      opmath_t softmax_max = -1e100;
      opmath_t softmax_sum = 0.0;

      // get max
      for (int j = 0; j < K; j++) {
        softmax_max = max(softmax_max, p_mask_shm[j]);
      }

      // get sumexp
      for (int j = 0; j < K; j++) {
        opmath_t exp_results = exp(p_mask_shm[j] - softmax_max);
        p_mask_shm[j] = exp_results;
        softmax_sum += exp_results;
      }

      // normalize
      for (int j = 0; j < K; j++) {
        p_mask_shm[j] /= softmax_sum;
      }
    }
    __syncthreads();
  }

  int offset_idx = 0;
  int mask_idx = 0;

  const int w_stride = G * D;
  const int base_ptr = gi * D + di_s;

  const scalar_t *p_value_ptr =
      p_value + (bi * (height_in * width_in)) * (G * D);
  opmath_t *grad_im_ptr = grad_im + (bi * (height_in * width_in)) * (G * D);

  // Kernel-centre position of this output location in input coordinates,
  // then shifted to the (scaled) position of tap (0, 0).
  const int p0_w = ((dilation_w * (kernel_w - 1)) >> 1) - pad_w +
                   (qi % width_out) * stride_w;
  const int p0_h = ((dilation_h * (kernel_h - 1)) >> 1) - pad_h +
                   (qi / width_out) * stride_h;
  const opmath_t p0_w_ =
      p0_w - ((dilation_w * (kernel_w - 1)) >> 1) * offset_scale;
  const opmath_t p0_h_ =
      p0_h - ((dilation_h * (kernel_h - 1)) >> 1) * offset_scale;
  const int center_h = kernel_h / 2;
  const int center_w = kernel_w / 2;

  grad_offset += (bi * Q + qi) * padded_offset_dim + gi * K * 3;
  opmath_t *grad_offset_softmax = grad_offset + K * 2;

  opmath_t reg_grad_offset[3] = {0.};
  for (int i = 0; i < kernel_w; ++i) {
    for (int j = 0; j < kernel_h; ++j) {
      if (i != center_w || j != center_h || !remove_center) {
        const opmath_t w_im =
            p0_w_ + (i * dilation_w + (opmath_t)p_offset_ptr[offset_idx]) *
                        offset_scale;
        const opmath_t h_im =
            p0_h_ + (j * dilation_h + (opmath_t)p_offset_ptr[offset_idx + 1]) *
                        offset_scale;
        const opmath_t attn = p_mask_shm[mask_idx];

        // Per-thread partials for this tap: [d/dx, d/dy, d/dmask].
        reg_grad_offset[0] = 0;
        reg_grad_offset[1] = 0;
        reg_grad_offset[2] = 0;

        if (h_im > -1 && w_im > -1 && h_im < height_in && w_im < width_in) {
          ms_deform_attn_col2im_bilinear<scalar_t, transfer_t, d_stride>(
              p_value_ptr, height_in, width_in, h_im, w_im, attn, w_stride,
              base_ptr, offset_scale, offset_scale, top_grad, grad_im_ptr,
              reg_grad_offset);
        }

        // Sum the partials over the channel slices of the group; the total
        // ends up in the lane with threadIdx.x == 0.
        for (uint32_t offset = blockDim.x >> 1; offset > 0; offset >>= 1) {
          reg_grad_offset[0] +=
              __shfl_down_sync(lane_mask, reg_grad_offset[0], offset);
          reg_grad_offset[1] +=
              __shfl_down_sync(lane_mask, reg_grad_offset[1], offset);
          reg_grad_offset[2] +=
              __shfl_down_sync(lane_mask, reg_grad_offset[2], offset);
        }

        if (threadIdx.x == 0) {
          *(grad_offset) = reg_grad_offset[0];     // B x H x W x G x L x K x 3
          *(grad_offset + 1) = reg_grad_offset[1]; // B x H x W x G x L x K x 3
          if (softmax) {
            // Deferred: the softmax Jacobian is applied after all taps.
            cache_g_mask_before_softmax[mask_idx] = reg_grad_offset[2] * attn;
          } else {
            grad_offset_softmax[mask_idx] = reg_grad_offset[2];
          }
        }

        offset_idx += 2;
        mask_idx += 1;
        grad_offset += 2;
      }
    }
  }

  // backward for softmax: grad_logit[i] = g[i]*p[i] - p[i] * sum_j g[j]*p[j]
  if (softmax) {
    if (threadIdx.x == 0) {
      opmath_t sum = 0.;
#pragma unroll
      for (int i = 0; i < K; ++i) {
        sum += cache_g_mask_before_softmax[i];
      }
#pragma unroll
      for (int i = 0; i < K; ++i) {
        *(grad_offset_softmax) =
            cache_g_mask_before_softmax[i] - p_mask_shm[i] * sum;
        grad_offset_softmax += 1;
      }
    }
  }
}
// Selects and launches the backward (col2im) DCNv4 kernel for a fixed
// d_stride / vector transfer type.
//
// The kernel variant is chosen from the number of sampling taps K (9, or 8
// when remove_center is set), whether softmax is applied to the mask, and
// whether check_backward_warpp(d_stride, D) allows the warp-shuffle kernel.
//
// block_thread is the requested number of threads per block; it is split into
// (D / d_stride) x G x block_multiplier threads.
//
// Throws std::invalid_argument for an unsupported K and raises through
// AT_ASSERTM for an invalid launch configuration or a failed launch.
template <typename scalar_t, typename stride_type, int d_stride>
void _dcnv4_col2im_cuda(
    cudaStream_t stream,
    const scalar_t *value,       // B, H * W, (G * D)
    const scalar_t *p_offset,    // B, H * W, (G*K*3)
    const scalar_t *grad_output, // B, H_out*W_out, G * D
    const int kernel_h, const int kernel_w, const int stride_h,
    const int stride_w, const int pad_h, const int pad_w, const int dilation_h,
    const int dilation_w, const int G, const int D, const int B,
    const int height_in, const int width_in, const int height_out,
    const int width_out, const opmath_t offset_scale, const int remove_center,
    opmath_t *grad_im, opmath_t *grad_offset, const int block_thread,
    const bool softmax, const int padded_offset_dim) {
  const int Q = height_out * width_out;
  int K = kernel_h * kernel_w;
  if (remove_center) {
    K -= 1;
  }

  // Validate the launch geometry before it is used as a divisor. These were
  // either unchecked or plain assert()s, which vanish under NDEBUG.
  const int group_threads = D / d_stride;
  AT_ASSERTM(G > 0 && group_threads > 0,
             "invalid group configuration: G=", G, ", D=", D,
             ", d_stride=", d_stride);
  const int block_multiplier = block_thread / group_threads / G;
  AT_ASSERTM(block_multiplier > 0, "block_thread (", block_thread,
             ") is smaller than G * D / d_stride (", G * group_threads, ")");
  AT_ASSERTM((B * Q) % block_multiplier == 0, "B * Q (", B * Q,
             ") must be divisible by the block multiplier (",
             block_multiplier, ")");

  const bool use_warp_reduce = check_backward_warpp(d_stride, D);

  auto kernel =
      backward_kernel_dcn_warp_primitive<scalar_t, d_stride, stride_type, 1, 9, false>;
  if (softmax) {
    switch (K) {
    case 9:
      if (use_warp_reduce) {
        kernel = backward_kernel_dcn_warp_primitive<scalar_t, d_stride, stride_type, 1, 9, true>;
      } else {
        kernel = backward_kernel_dcn<scalar_t, d_stride, stride_type, 1, 9, true>;
      }
      break;
    case 8:
      if (use_warp_reduce) {
        kernel = backward_kernel_dcn_warp_primitive<scalar_t, d_stride, stride_type, 1, 8, true>;
      } else {
        kernel = backward_kernel_dcn<scalar_t, d_stride, stride_type, 1, 8, true>;
      }
      break;
    default:
      printf("K=%d\n", K);
      throw std::invalid_argument("invalid kernel shape");
    }
  } else {
    switch (K) {
    case 9:
      if (use_warp_reduce) {
        kernel = backward_kernel_dcn_warp_primitive<scalar_t, d_stride, stride_type, 1, 9, false>;
      } else {
        kernel = backward_kernel_dcn<scalar_t, d_stride, stride_type, 1, 9, false>;
      }
      break;
    case 8:
      if (use_warp_reduce) {
        kernel = backward_kernel_dcn_warp_primitive<scalar_t, d_stride, stride_type, 1, 8, false>;
      } else {
        kernel = backward_kernel_dcn<scalar_t, d_stride, stride_type, 1, 8, false>;
      }
      break;
    default:
      printf("K=%d\n", K);
      throw std::invalid_argument("invalid kernel shape");
    }
  }

  dim3 num_blocks(B * Q / block_multiplier);
  dim3 num_threads(group_threads, G, block_multiplier);

  // Two K-sized opmath_t arrays per (location, group): mask weights and the
  // pre-softmax mask gradient. The non-warp kernel additionally gets three
  // values per thread for its shared-memory reduction.
  int shm_size = sizeof(opmath_t) * (G * block_multiplier * K) * 2;
  if (!use_warp_reduce) {
    shm_size = sizeof(opmath_t) * ((G * block_multiplier * K) * 2 +
                                   G * block_multiplier * group_threads * 3);
  }

  cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize,
                       shm_size);

  kernel<<<num_blocks, num_threads, shm_size, stream>>>(
      value, p_offset, grad_output, G, D, Q, kernel_h, kernel_w, stride_h,
      stride_w, pad_h, pad_w, dilation_h, dilation_w, height_in, width_in,
      height_out, width_out, offset_scale, remove_center, block_multiplier,
      grad_im, grad_offset, padded_offset_dim);

  cudaError_t err = cudaGetLastError();
  if (err != cudaSuccess) {
    printf("error in dcnv4_col2im_cuda: %s\n", cudaGetErrorString(err));
    printf("launch arguments: gridDim=(%d, %d, %d), blockDim=(%d, %d, %d), "
           "shm_size=%d\n\n",
           num_blocks.x, num_blocks.y, num_blocks.z, num_threads.x,
           num_threads.y, num_threads.z, shm_size);
    AT_ASSERTM(false, "kernel launch error");
  }
}
// Runtime dispatcher for the backward (col2im) DCNv4 launcher.
//
// Maps the runtime d_stride (channels handled per thread) and the element
// size of scalar_t onto the compile-time d_stride and the vector type used
// for d_stride-wide memory transfers, then forwards every argument to
// _dcnv4_col2im_cuda.
//
// Supported combinations:
//   2-byte scalar_t: d_stride in {1, 2, 4, 8, 16}
//   4-byte scalar_t: d_stride in {1, 2, 4, 8}
// Anything else raises through AT_ASSERTM instead of silently launching
// nothing and leaving the gradient buffers untouched.
template <typename scalar_t>
void dcnv4_col2im_cuda(
    cudaStream_t stream,
    const scalar_t *value,       // B, H * W, (G * D)
    const scalar_t *p_offset,    // B, H * W, (G*K*3)
    const scalar_t *grad_output, // B, H_out*W_out, G * D
    const int kernel_h, const int kernel_w, const int stride_h,
    const int stride_w, const int pad_h, const int pad_w, const int dilation_h,
    const int dilation_w, const int G, const int D, const int B,
    const int height_in, const int width_in, const int height_out,
    const int width_out, const opmath_t offset_scale, const int remove_center,
    opmath_t *grad_im, opmath_t *grad_offset, const int d_stride,
    const int block_thread, const bool softmax, const int padded_offset_dim) {
  // Runtime checks: assert() is compiled out under NDEBUG and d_stride == 0
  // would otherwise be a modulo by zero.
  AT_ASSERTM(d_stride > 0, "d_stride must be positive, got ", d_stride);
  AT_ASSERTM(D % d_stride == 0, "group channels (", D,
             ") must be divisible by d_stride (", d_stride, ")");

  const int size_scalar = sizeof(scalar_t);
  if (size_scalar == 2) {
    switch (d_stride) {
    case 1:
      _dcnv4_col2im_cuda<scalar_t, scalar_t, 1>(
          stream, value, p_offset, grad_output, kernel_h, kernel_w, stride_h,
          stride_w, pad_h, pad_w, dilation_h, dilation_w, G, D, B, height_in,
          width_in, height_out, width_out, offset_scale, remove_center, grad_im,
          grad_offset, block_thread, softmax, padded_offset_dim);
      break;
    case 2:
      _dcnv4_col2im_cuda<scalar_t, uint, 2>(
          stream, value, p_offset, grad_output, kernel_h, kernel_w, stride_h,
          stride_w, pad_h, pad_w, dilation_h, dilation_w, G, D, B, height_in,
          width_in, height_out, width_out, offset_scale, remove_center, grad_im,
          grad_offset, block_thread, softmax, padded_offset_dim);
      break;
    case 4:
      _dcnv4_col2im_cuda<scalar_t, uint2, 4>(
          stream, value, p_offset, grad_output, kernel_h, kernel_w, stride_h,
          stride_w, pad_h, pad_w, dilation_h, dilation_w, G, D, B, height_in,
          width_in, height_out, width_out, offset_scale, remove_center, grad_im,
          grad_offset, block_thread, softmax, padded_offset_dim);
      break;
    case 8:
      _dcnv4_col2im_cuda<scalar_t, uint4, 8>(
          stream, value, p_offset, grad_output, kernel_h, kernel_w, stride_h,
          stride_w, pad_h, pad_w, dilation_h, dilation_w, G, D, B, height_in,
          width_in, height_out, width_out, offset_scale, remove_center, grad_im,
          grad_offset, block_thread, softmax, padded_offset_dim);
      break;
    case 16:
      _dcnv4_col2im_cuda<scalar_t, ulonglong4, 16>(
          stream, value, p_offset, grad_output, kernel_h, kernel_w, stride_h,
          stride_w, pad_h, pad_w, dilation_h, dilation_w, G, D, B, height_in,
          width_in, height_out, width_out, offset_scale, remove_center, grad_im,
          grad_offset, block_thread, softmax, padded_offset_dim);
      break;
    default:
      AT_ASSERTM(false, "unsupported d_stride (", d_stride,
                 ") for 2-byte element types; expected 1, 2, 4, 8 or 16");
    }
  } else {
    // The transfer types below are sized for 4-byte elements only.
    AT_ASSERTM(size_scalar == 4, "unsupported element size (", size_scalar,
               " bytes); only 2- and 4-byte floating types are supported");
    switch (d_stride) {
    case 1:
      _dcnv4_col2im_cuda<scalar_t, uint, 1>(
          stream, value, p_offset, grad_output, kernel_h, kernel_w, stride_h,
          stride_w, pad_h, pad_w, dilation_h, dilation_w, G, D, B, height_in,
          width_in, height_out, width_out, offset_scale, remove_center, grad_im,
          grad_offset, block_thread, softmax, padded_offset_dim);
      break;
    case 2:
      _dcnv4_col2im_cuda<scalar_t, uint2, 2>(
          stream, value, p_offset, grad_output, kernel_h, kernel_w, stride_h,
          stride_w, pad_h, pad_w, dilation_h, dilation_w, G, D, B, height_in,
          width_in, height_out, width_out, offset_scale, remove_center, grad_im,
          grad_offset, block_thread, softmax, padded_offset_dim);
      break;
    case 4:
      _dcnv4_col2im_cuda<scalar_t, uint4, 4>(
          stream, value, p_offset, grad_output, kernel_h, kernel_w, stride_h,
          stride_w, pad_h, pad_w, dilation_h, dilation_w, G, D, B, height_in,
          width_in, height_out, width_out, offset_scale, remove_center, grad_im,
          grad_offset, block_thread, softmax, padded_offset_dim);
      break;
    case 8:
      _dcnv4_col2im_cuda<scalar_t, ulonglong4, 8>(
          stream, value, p_offset, grad_output, kernel_h, kernel_w, stride_h,
          stride_w, pad_h, pad_w, dilation_h, dilation_w, G, D, B, height_in,
          width_in, height_out, width_out, offset_scale, remove_center, grad_im,
          grad_offset, block_thread, softmax, padded_offset_dim);
      break;
    default:
      AT_ASSERTM(false, "unsupported d_stride (", d_stride,
                 ") for 4-byte element types; expected 1, 2, 4 or 8");
    }
  }
}
\ No newline at end of file
/*!
**************************************************************************************************
* Deformable DETR
* Copyright (c) 2020 SenseTime. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 [see LICENSE for details]
**************************************************************************************************
* Modified from
*https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
**************************************************************************************************
*/
#include "cuda/dcnv4_im2col_cuda.cuh"
#include "cuda/dcnv4_col2im_cuda.cuh"
#include <vector>
#include <ATen/ATen.h>
#include <ATen/OpMathType.h>
#include <ATen/cuda/CUDAContext.h>
#include <cooperative_groups.h>
#include <cooperative_groups/memcpy_async.h>
#include <cuda.h>
#include <cuda_fp16.h>
#include <cuda_runtime.h>
#include <torch/torch.h>
// Forward pass of DCNv4 on CUDA.
//
// value:    contiguous CUDA tensor read as (batch, height_in, width_in,
//           channels) with channels == group * group_channels.
// p_offset: contiguous CUDA tensor whose last dimension (padded_offset_dim)
//           holds, per group, 2*K sampling offsets followed by K mask values,
//           where K = kernel_h * kernel_w (minus one when remove_center is
//           set); one such row per output location.
// d_stride, block_thread and softmax are forwarded to dcnv4_im2col_cuda.
//
// The batch is processed in chunks of min(batch, im2col_step) samples.
// Returns a tensor of shape (batch, height_out, width_out, channels) with the
// dtype/device of `value`.
at::Tensor dcnv4_cuda_forward(
    const at::Tensor &value,
    const at::Tensor &p_offset,
    const int kernel_h, const int kernel_w, const int stride_h,
    const int stride_w, const int pad_h, const int pad_w, const int dilation_h,
    const int dilation_w, const int group, const int group_channels,
    const float offset_scale, const int im2col_step, const int remove_center,
    const int d_stride, const int block_thread, const bool softmax) {
  AT_ASSERTM(value.is_contiguous(), "input tensor has to be contiguous");
  AT_ASSERTM(value.is_cuda(), "input must be a CUDA tensor");
  AT_ASSERTM(p_offset.is_contiguous(), "input tensor has to be contiguous");
  AT_ASSERTM(p_offset.is_cuda(), "input must be a CUDA tensor");

  const int batch = value.size(0);
  const int height_in = value.size(1);
  const int width_in = value.size(2);
  const int channels = value.size(3);
  const int padded_offset_dim = p_offset.size(3);

  // tensor core requirement
  assert(padded_offset_dim % 8 == 0);

  const int height_out =
      (height_in + 2 * pad_h - (dilation_h * (kernel_h - 1) + 1)) / stride_h +
      1;
  const int width_out =
      (width_in + 2 * pad_w - (dilation_w * (kernel_w - 1) + 1)) / stride_w + 1;

  // The kernels read group * K * 3 values per output location and index the
  // offset tensor as (batch * height_out * width_out) rows; reject tensors
  // that are too small instead of reading out of bounds.
  int taps = kernel_h * kernel_w;
  if (remove_center) {
    taps -= 1;
  }
  AT_ASSERTM(padded_offset_dim >= group * taps * 3,
             "offset tensor last dimension (", padded_offset_dim,
             ") must be at least group * K * 3 (", group * taps * 3, ")");
  AT_ASSERTM(p_offset.numel() == static_cast<int64_t>(batch) * height_out *
                                     width_out * padded_offset_dim,
             "offset tensor must hold one row per output location");

  const int im2col_step_ = std::min(batch, im2col_step);
  // Guards the modulo/division below (empty batch or non-positive step).
  AT_ASSERTM(im2col_step_ > 0,
             "batch and im2col_step must be positive, got batch=", batch,
             ", im2col_step=", im2col_step);
  AT_ASSERTM(batch % im2col_step_ == 0, "batch(", batch,
             ") must be divisible by im2col_step(", im2col_step_, ")");
  AT_ASSERTM(channels == (group * group_channels),
             "Input channels and group times group channels wont match: (",
             channels, " vs ", group * group_channels, ").");

  auto output = at::zeros(
      {batch, height_out, width_out, group * group_channels}, value.options());

  const int batch_n = im2col_step_;
  auto output_n = output.view({batch / batch_n, batch_n, height_out, width_out,
                               group * group_channels});
  // 64-bit element counts so the per-chunk pointer offsets cannot overflow.
  const int64_t per_value_size =
      static_cast<int64_t>(height_in) * width_in * channels;
  const int64_t per_offset_size =
      static_cast<int64_t>(height_out) * width_out * padded_offset_dim;

  for (int n = 0; n < batch / im2col_step_; ++n) {
    auto columns = output_n.select(0, n);
    const int64_t first_sample = static_cast<int64_t>(n) * im2col_step_;
    AT_DISPATCH_FLOATING_TYPES_AND2(
        at::ScalarType::Half, at::ScalarType::BFloat16, value.scalar_type(),
        "dcnv4_forward_cuda", ([&] {
          dcnv4_im2col_cuda(
              at::cuda::getCurrentCUDAStream(),
              value.data_ptr<scalar_t>() + first_sample * per_value_size,
              p_offset.data_ptr<scalar_t>() + first_sample * per_offset_size,
              columns.data_ptr<scalar_t>(), kernel_h, kernel_w, stride_h,
              stride_w, pad_h, pad_w, dilation_h, dilation_w, group,
              group_channels, batch_n, height_in, width_in, height_out,
              width_out, offset_scale, remove_center, d_stride, block_thread,
              softmax, padded_offset_dim);
        }));
  }

  return output;
}
// Backward pass of DCNv4 on CUDA.
//
// value / p_offset are the forward inputs (same layout as in
// dcnv4_cuda_forward); grad_output is the contiguous CUDA gradient of the
// forward output and is viewed as (batch, height_out, width_out, group,
// group_channels).
//
// Returns {grad_input, grad_offset}, shaped like value and p_offset and
// converted back to the dtype of `value`.
//
// The kernels write gradients through opmath_t pointers, so for reduced
// precision inputs the buffers are allocated in float and cast back at the
// end. NOTE(review): this assumes opmath_t resolves to float for both Half
// and BFloat16 (its definition is outside this file) — the dispatch below
// enables BFloat16, and a BFloat16 buffer cannot be accessed as opmath_t.
std::vector<at::Tensor>
dcnv4_cuda_backward(
    const at::Tensor &value,
    const at::Tensor &p_offset,
    const int kernel_h, const int kernel_w, const int stride_h,
    const int stride_w, const int pad_h, const int pad_w, const int dilation_h,
    const int dilation_w, const int group, const int group_channels,
    const float offset_scale, const int im2col_step, const at::Tensor &grad_output,
    const int remove_center, const int d_stride, const int block_thread,
    const bool softmax) {
  AT_ASSERTM(value.is_contiguous(), "input tensor has to be contiguous");
  AT_ASSERTM(p_offset.is_contiguous(), "offset tensor has to be contiguous");
  AT_ASSERTM(grad_output.is_contiguous(),
             "grad_output tensor has to be contiguous");
  AT_ASSERTM(value.is_cuda(), "input must be a CUDA tensor");
  AT_ASSERTM(p_offset.is_cuda(), "offset must be a CUDA tensor");
  AT_ASSERTM(grad_output.is_cuda(), "grad_output must be a CUDA tensor");

  const int batch = value.size(0);
  const int height_in = value.size(1);
  const int width_in = value.size(2);
  const int channels = value.size(3);
  const int padded_offset_dim = p_offset.size(3);
  assert(padded_offset_dim % 8 == 0);

  const int height_out =
      (height_in + 2 * pad_h - (dilation_h * (kernel_h - 1) + 1)) / stride_h +
      1;
  const int width_out =
      (width_in + 2 * pad_w - (dilation_w * (kernel_w - 1) + 1)) / stride_w + 1;

  // The kernels read and write group * K * 3 values per output location.
  int taps = kernel_h * kernel_w;
  if (remove_center) {
    taps -= 1;
  }
  AT_ASSERTM(padded_offset_dim >= group * taps * 3,
             "offset tensor last dimension (", padded_offset_dim,
             ") must be at least group * K * 3 (", group * taps * 3, ")");
  AT_ASSERTM(p_offset.numel() == static_cast<int64_t>(batch) * height_out *
                                     width_out * padded_offset_dim,
             "offset tensor must hold one row per output location");

  const int im2col_step_ = std::min(batch, im2col_step);
  // Guards the modulo/division below (empty batch or non-positive step).
  AT_ASSERTM(im2col_step_ > 0,
             "batch and im2col_step must be positive, got batch=", batch,
             ", im2col_step=", im2col_step);
  AT_ASSERTM(batch % im2col_step_ == 0, "batch(", batch,
             ") must be divisible by im2col_step(", im2col_step_, ")");
  AT_ASSERTM(channels == (group * group_channels),
             "Input channels and group times group channels wont match: (",
             channels, " vs ", group * group_channels, ").");

  // Accumulate in float for 16-bit inputs; other dtypes keep their own type.
  const at::ScalarType value_dtype = value.scalar_type();
  const bool reduced_precision =
      value_dtype == at::kHalf || value_dtype == at::kBFloat16;
  const at::ScalarType grad_dtype = reduced_precision ? at::kFloat : value_dtype;

  auto grad_input = at::zeros_like(value, value.options().dtype(grad_dtype));
  auto grad_offset =
      at::zeros_like(p_offset, p_offset.options().dtype(grad_dtype));

  const int batch_n = im2col_step_;
  auto grad_output_n =
      grad_output.view({batch / batch_n, batch_n, height_out, width_out,
                        group, group_channels});
  // 64-bit element counts so the per-chunk pointer offsets cannot overflow.
  const int64_t per_value_size =
      static_cast<int64_t>(height_in) * width_in * channels;
  const int64_t per_offset_size =
      static_cast<int64_t>(height_out) * width_out * padded_offset_dim;

  for (int n = 0; n < batch / im2col_step_; ++n) {
    auto columns = grad_output_n.select(0, n);
    const int64_t first_sample = static_cast<int64_t>(n) * im2col_step_;
    AT_DISPATCH_FLOATING_TYPES_AND2(
        at::ScalarType::Half, at::ScalarType::BFloat16, value.scalar_type(),
        "dcnv4_backward_cuda", ([&] {
          dcnv4_col2im_cuda(
              at::cuda::getCurrentCUDAStream(),
              value.data_ptr<scalar_t>() + first_sample * per_value_size,
              p_offset.data_ptr<scalar_t>() + first_sample * per_offset_size,
              columns.data_ptr<scalar_t>(), kernel_h, kernel_w, stride_h,
              stride_w, pad_h, pad_w, dilation_h, dilation_w, group,
              group_channels, batch_n, height_in, width_in, height_out,
              width_out, offset_scale, remove_center,
              grad_input.data_ptr<opmath_t>() + first_sample * per_value_size,
              grad_offset.data_ptr<opmath_t>() + first_sample * per_offset_size,
              d_stride, block_thread, softmax, padded_offset_dim);
        }));
  }

  if (reduced_precision) {
    return {grad_input.to(value_dtype), grad_offset.to(value_dtype)};
  }
  return {grad_input, grad_offset};
}
/*!
**************************************************************************************************
* Deformable DETR
* Copyright (c) 2020 SenseTime. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 [see LICENSE for details]
**************************************************************************************************
* Modified from
*https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
**************************************************************************************************
*/
#pragma once
#include <torch/extension.h>
// Forward pass of the DCNv4 operator on CUDA tensors.
//
// value and p_offset must be contiguous CUDA tensors. value is read as
// (batch, height_in, width_in, channels) and channels must equal
// group * group_channels; the size of p_offset's last dimension is passed to
// the kernels as the padded per-location offset/mask stride.
// The batch is processed in chunks of min(batch, im2col_step) samples.
// d_stride, block_thread and softmax are forwarded to the kernel launcher.
// Returns a tensor of shape (batch, height_out, width_out, channels).
at::Tensor dcnv4_cuda_forward(
const at::Tensor &value,
const at::Tensor &p_offset,
const int kernel_h, const int kernel_w, const int stride_h,
const int stride_w, const int pad_h, const int pad_w, const int dilation_h,
const int dilation_w, const int group, const int group_channels,
const float offset_scale, const int im2col_step, const int remove_center,
const int d_stride, const int block_thread, const bool softmax);
// Backward pass of the DCNv4 operator on CUDA tensors.
//
// value and p_offset are the forward inputs; grad_output is the gradient of
// the forward output. All three must be contiguous CUDA tensors.
// Returns a two-element vector {grad_input, grad_offset}: gradients with the
// shapes of value and p_offset respectively.
std::vector<at::Tensor>
dcnv4_cuda_backward(
const at::Tensor &value,
const at::Tensor &p_offset,
const int kernel_h, const int kernel_w, const int stride_h,
const int stride_w, const int pad_h, const int pad_w, const int dilation_h,
const int dilation_w, const int group, const int group_channels,
const float offset_scale, const int im2col_step, const at::Tensor &grad_output,
const int remove_center, const int d_stride, const int block_thread,
const bool softmax);
\ No newline at end of file
#include <algorithm>
#include <cstdio>
#include <cstring>
#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h>
#include <THC/THCAtomics.cuh>
#include <ATen/ATen.h>
#include <ATen/OpMathType.h>
#include <ATen/cuda/CUDAContext.h>
#include <cooperative_groups.h>
#include <cooperative_groups/memcpy_async.h>
#include <cuda.h>
#include <cuda_fp16.h>
#include <cuda_runtime.h>
#include "common.h"
// Forward DCNv4 kernel that stages the mask weights in dynamic shared memory.
//
// Thread layout: threadIdx.x indexes a slice of d_stride channels inside a
// group, threadIdx.y the group (G groups), threadIdx.z one of
// `block_multiplier` consecutive output locations handled by the block.
// Each thread accumulates d_stride output channels over the K sampling taps
// and stores them with a single transfer_t-wide write.
//
// Requires block_multiplier * G * L * K opmath_t elements of dynamic shared
// memory (one K-sized mask array per (location, group)).
template <typename scalar_t, int d_stride, typename transfer_t, int L, int K,
          bool softmax>
__global__ void forward_kernel_dcn(
    const scalar_t *p_value, const scalar_t *p_offset, scalar_t *p_output,
    const int G, const int D, const int Q, const int kernel_h,
    const int kernel_w, const int stride_h, const int stride_w, const int pad_h,
    const int pad_w, const int dilation_h, const int dilation_w,
    const int height_in, const int width_in, const int height_out,
    const int width_out, const opmath_t offset_scale, const int remove_center,
    const int block_multiplier, const int padded_offset_dim) {
  const int qi = (blockIdx.x * block_multiplier % Q) + threadIdx.z;
  const int bi = blockIdx.x * block_multiplier / Q;
  const int di_s = threadIdx.x * d_stride;
  const int gi = threadIdx.y;
  constexpr int li = 0;

  extern __shared__ char _s[];

  opmath_t *const p_mask_shm =
      (opmath_t *)(_s) + ((threadIdx.z * G + gi) * L + li) * K;
  opmath_t p_out_shm[d_stride] = {0.};

  const scalar_t *p_offset_ptr =
      p_offset + (bi * Q + qi) * padded_offset_dim + gi * K * 3;
  const int num_thread = (D / d_stride);
  const int num_iter = K / num_thread;
  const int remainder = K - num_iter * num_thread;

  // Cooperative copy of the K mask values (stored after the 2*K offsets) from
  // global to shared memory; each thread of the group copies a strided subset.
  for (int i = 0; i < num_iter; i++) {
    *(p_mask_shm + num_thread * i + threadIdx.x) =
        *(p_offset_ptr + K * 2 + num_thread * i + threadIdx.x);
  }
  if (remainder > 0 && threadIdx.x < remainder) {
    *(p_mask_shm + num_thread * num_iter + threadIdx.x) =
        *(p_offset_ptr + K * 2 + num_thread * num_iter + threadIdx.x);
  }
  // Every thread reads mask entries written by other threads below, so the
  // barrier is required whether or not softmax is applied.
  __syncthreads();

  if (softmax) {
    // Numerically stable softmax over the K taps, done by one thread per group.
    if (threadIdx.x == 0) {
      opmath_t softmax_max = -1e100;
      opmath_t softmax_sum = 0.0;

      // get max
      for (int j = 0; j < K; j++) {
        softmax_max = max(softmax_max, p_mask_shm[j]);
      }

      // get sumexp
      for (int j = 0; j < K; j++) {
        opmath_t exp_results = exp(p_mask_shm[j] - softmax_max);
        p_mask_shm[j] = exp_results;
        softmax_sum += exp_results;
      }

      // normalize
      for (int j = 0; j < K; j++) {
        p_mask_shm[j] /= softmax_sum;
      }
    }
    __syncthreads();
  }

  int offset_idx = 0;
  int mask_idx = 0;

  const int w_stride = G * D;
  const int base_ptr = gi * D + di_s;

  const scalar_t *p_value_ptr =
      p_value + (bi * (height_in * width_in)) * (G * D);

  // Kernel-centre position of this output location in input coordinates,
  // then shifted to the (scaled) position of tap (0, 0).
  const int p0_w = ((dilation_w * (kernel_w - 1)) >> 1) - pad_w +
                   (qi % width_out) * stride_w;
  const int p0_h = ((dilation_h * (kernel_h - 1)) >> 1) - pad_h +
                   (qi / width_out) * stride_h;
  const opmath_t p0_w_ =
      p0_w - ((dilation_w * (kernel_w - 1)) >> 1) * offset_scale;
  const opmath_t p0_h_ =
      p0_h - ((dilation_h * (kernel_h - 1)) >> 1) * offset_scale;
  const int center_h = kernel_h / 2;
  const int center_w = kernel_w / 2;

  int out_idx = ((bi * Q + qi) * G + gi) * D + di_s;

  for (int i = 0; i < kernel_w; ++i) {
    for (int j = 0; j < kernel_h; ++j) {
      if (i != center_w || j != center_h || !remove_center) {
        const opmath_t w_im =
            p0_w_ + (i * dilation_w + (opmath_t)p_offset_ptr[offset_idx]) *
                        offset_scale;
        const opmath_t h_im =
            p0_h_ + (j * dilation_h + (opmath_t)p_offset_ptr[offset_idx + 1]) *
                        offset_scale;
        const opmath_t attn = p_mask_shm[mask_idx];

        // Taps that fall entirely outside the input contribute nothing.
        if (h_im > -1 && w_im > -1 && h_im < height_in && w_im < width_in) {
          ms_deform_attn_im2col_bilinear<scalar_t, transfer_t, d_stride>(
              p_out_shm, p_value_ptr, height_in, width_in, h_im, w_im, attn,
              w_stride, base_ptr);
        }
        offset_idx += 2;
        mask_idx += 1;
      }
    }
  }

  // Narrow the accumulators to scalar_t in place (front of the same buffer),
  // then emit all d_stride channels with one transfer_t store.
  // NOTE(review): the in-place narrowing relies on
  // sizeof(scalar_t) <= sizeof(opmath_t) — confirm for every dispatched type.
  scalar_t *fp16_regs = (scalar_t *)(p_out_shm);
#pragma unroll
  for (int ds = 0; ds < d_stride; ds++) {
    fp16_regs[ds] = p_out_shm[ds];
  }

  *(transfer_t *)(p_output + out_idx) = *(transfer_t *)(p_out_shm);
}
// Forward DCNv4 kernel that keeps the K mask weights in a per-thread local
// array instead of shared memory, so it needs no block-level barrier.
//
// Thread layout: threadIdx.x -> slice of d_stride channels inside a group,
// threadIdx.y -> group, threadIdx.z -> one of `block_multiplier` consecutive
// output locations. Each thread accumulates d_stride output channels over the
// sampling taps and writes them with one transfer_t-wide store.
template <typename scalar_t, int d_stride, typename transfer_t, int L, int K,
          bool softmax>
__global__ void forward_kernel_dcn_reg(
    const scalar_t *p_value, const scalar_t *p_offset, scalar_t *p_output,
    const int G, const int D, const int Q, const int kernel_h,
    const int kernel_w, const int stride_h, const int stride_w, const int pad_h,
    const int pad_w, const int dilation_h, const int dilation_w,
    const int height_in, const int width_in, const int height_out,
    const int width_out, const opmath_t offset_scale, const int remove_center,
    const int block_multiplier, const int padded_offset_dim) {
  // Which (sample, output location, group, channel slice) this thread owns.
  const int query = (blockIdx.x * block_multiplier % Q) + threadIdx.z;
  const int sample = blockIdx.x * block_multiplier / Q;
  const int grp = threadIdx.y;
  const int chan_begin = threadIdx.x * d_stride;

  opmath_t weights[K] = {0.};
  opmath_t acc[d_stride] = {0.};

  // Per-group row: 2*K offsets followed by K mask values.
  const scalar_t *offsets =
      p_offset + (sample * Q + query) * padded_offset_dim + grp * K * 3;
  const scalar_t *raw_mask = offsets + K * 2;
  for (int k = 0; k < K; ++k) {
    weights[k] = raw_mask[k];
  }

  if (softmax) {
    // Stable softmax over the K taps: subtract the max, exponentiate,
    // normalise by the sum.
    opmath_t peak = -1e100;
    for (int k = 0; k < K; ++k) {
      peak = max(peak, weights[k]);
    }

    opmath_t total = 0.0;
    for (int k = 0; k < K; ++k) {
      const opmath_t e = exp(weights[k] - peak);
      weights[k] = e;
      total += e;
    }

    for (int k = 0; k < K; ++k) {
      weights[k] /= total;
    }
  }

  const int w_stride = G * D;
  const int base_ptr = grp * D + chan_begin;
  const scalar_t *sample_value =
      p_value + (sample * (height_in * width_in)) * (G * D);

  // Half kernel extent, used both for the centre position and for the
  // scaled shift back to tap (0, 0).
  const int half_span_w = (dilation_w * (kernel_w - 1)) >> 1;
  const int half_span_h = (dilation_h * (kernel_h - 1)) >> 1;
  const int centre_x = half_span_w - pad_w + (query % width_out) * stride_w;
  const int centre_y = half_span_h - pad_h + (query / width_out) * stride_h;
  const opmath_t origin_x = centre_x - half_span_w * offset_scale;
  const opmath_t origin_y = centre_y - half_span_h * offset_scale;

  const int mid_row = kernel_h / 2;
  const int mid_col = kernel_w / 2;

  const int out_idx = ((sample * Q + query) * G + grp) * D + chan_begin;

  int tap = 0; // offsets live at 2*tap / 2*tap+1, the weight at tap
  for (int col = 0; col < kernel_w; ++col) {
    for (int row = 0; row < kernel_h; ++row) {
      // The centre tap is skipped entirely when remove_center is set.
      if (remove_center && col == mid_col && row == mid_row) {
        continue;
      }

      const opmath_t w_im =
          origin_x +
          (col * dilation_w + (opmath_t)offsets[2 * tap]) * offset_scale;
      const opmath_t h_im =
          origin_y +
          (row * dilation_h + (opmath_t)offsets[2 * tap + 1]) * offset_scale;
      const opmath_t attn = weights[tap];

      const bool reachable =
          h_im > -1 && w_im > -1 && h_im < height_in && w_im < width_in;
      if (reachable) {
        ms_deform_attn_im2col_bilinear<scalar_t, transfer_t, d_stride>(
            acc, sample_value, height_in, width_in, h_im, w_im, attn,
            w_stride, base_ptr);
      }
      ++tap;
    }
  }

  // Convert the accumulators to scalar_t in place at the front of the same
  // buffer, then write all d_stride channels with one vector store.
  scalar_t *packed = (scalar_t *)(acc);
#pragma unroll
  for (int ds = 0; ds < d_stride; ds++) {
    packed[ds] = acc[ds];
  }

  *(transfer_t *)(p_output + out_idx) = *(transfer_t *)(acc);
}
// Selects and launches the forward (im2col) DCNv4 kernel for a fixed
// d_stride / vector transfer type.
//
// The forward_kernel_dcn_reg variant is chosen from the number of sampling
// taps K (9, or 8 when remove_center is set) and whether softmax is applied
// to the mask. block_thread is the requested number of threads per block; it
// is split into (D / d_stride) x G x block_multiplier threads. No dynamic
// shared memory is requested.
//
// Throws std::invalid_argument for an unsupported K and raises through
// AT_ASSERTM for an invalid launch configuration or a failed launch.
template <typename scalar_t, typename stride_type, int d_stride>
void _dcnv4_im2col_cuda(cudaStream_t stream,
                        const scalar_t *value,    // B, H * W, (G * D)
                        const scalar_t *p_offset, // B, H * W, G * K * 3)
                        scalar_t *output,         // B, H_out*W_out, G * D
                        const int kernel_h, const int kernel_w,
                        const int stride_h, const int stride_w,
                        const int pad_h, const int pad_w,
                        const int dilation_h, const int dilation_w,
                        const int G, const int D, const int B,
                        const int height_in, const int width_in,
                        const int height_out, const int width_out,
                        const opmath_t offset_scale,
                        const int remove_center, const int block_thread,
                        const int softmax,
                        const int padded_offset_dim) {
  const int Q = height_out * width_out;
  int K = kernel_h * kernel_w;
  if (remove_center) {
    K -= 1;
  }

  auto kernel = forward_kernel_dcn_reg<scalar_t, d_stride, stride_type, 1, 9, true>;
  if (softmax) {
    switch (K) {
    case 9:
      kernel = forward_kernel_dcn_reg<scalar_t, d_stride, stride_type, 1, 9, true>;
      break;
    case 8:
      kernel = forward_kernel_dcn_reg<scalar_t, d_stride, stride_type, 1, 8, true>;
      break;
    default:
      printf("K=%d\n", K);
      throw std::invalid_argument("invalid kernel shape");
    }
  } else {
    switch (K) {
    case 9:
      kernel = forward_kernel_dcn_reg<scalar_t, d_stride, stride_type, 1, 9, false>;
      break;
    case 8:
      kernel = forward_kernel_dcn_reg<scalar_t, d_stride, stride_type, 1, 8, false>;
      break;
    default:
      printf("K=%d\n", K);
      throw std::invalid_argument("invalid kernel shape");
    }
  }

  // Validate the launch geometry before it is used as a divisor. These were
  // either unchecked or plain assert()s, which vanish under NDEBUG.
  const int group_threads = D / d_stride;
  AT_ASSERTM(G > 0 && group_threads > 0,
             "invalid group configuration: G=", G, ", D=", D,
             ", d_stride=", d_stride);
  const int block_multiplier = block_thread / group_threads / G;
  AT_ASSERTM(block_multiplier > 0, "block_thread (", block_thread,
             ") is smaller than G * D / d_stride (", G * group_threads, ")");
  AT_ASSERTM((B * Q) % block_multiplier == 0, "B * Q (", B * Q,
             ") must be divisible by the block multiplier (",
             block_multiplier, ")");

  dim3 num_blocks(B * Q / block_multiplier);
  dim3 num_threads(group_threads, G, block_multiplier);

  int shm_size = 0;

  cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize,
                       shm_size);

  kernel<<<num_blocks, num_threads, shm_size, stream>>>(
      value, p_offset, output, G, D, Q, kernel_h, kernel_w, stride_h, stride_w,
      pad_h, pad_w, dilation_h, dilation_w, height_in, width_in, height_out,
      width_out, offset_scale, remove_center, block_multiplier, padded_offset_dim);

  cudaError_t err = cudaGetLastError();
  if (err != cudaSuccess) {
    printf("error in dcnv4_im2col_cuda: %s\n", cudaGetErrorString(err));
    printf("launch arguments: gridDim=(%d, %d, %d), blockDim=(%d, %d, %d), "
           "shm_size=%d\n\n",
           num_blocks.x, num_blocks.y, num_blocks.z, num_threads.x,
           num_threads.y, num_threads.z, shm_size);
    AT_ASSERTM(false, "kernel launch error");
  }
}
/*
 * Host-side dispatcher for the DCNv4 im2col launch.
 *
 * Selects the `_dcnv4_im2col_cuda<scalar_t, stride_type, d_stride>`
 * instantiation from the element width of `scalar_t` and the runtime
 * `d_stride`, then forwards every argument unchanged.
 *
 * NOTE(review): the stride_type picked in each case looks like it is chosen so
 * that sizeof(stride_type) == d_stride * sizeof(scalar_t) (one vectorized
 * access per thread) -- confirm against the kernel.
 *
 * Supported d_stride values:
 *   - 2-byte scalar_t: 1, 2, 4, 8, 16
 *   - otherwise (asserted to be 4-byte): 1, 2, 4, 8
 *
 * Throws std::invalid_argument for any other d_stride. Previously the 2-byte
 * branch had no default case and silently returned without launching a kernel,
 * leaving `output` unwritten.
 */
template <typename scalar_t>
void dcnv4_im2col_cuda(
    cudaStream_t stream,
    const scalar_t *value,    // B, H * W, (G * D)
    const scalar_t *p_offset, // B, H * W, G * K * 3)
    scalar_t *output,         // B, H_out*W_out, G * D
    const int kernel_h, const int kernel_w, const int stride_h,
    const int stride_w, const int pad_h, const int pad_w, const int dilation_h,
    const int dilation_w, const int G, const int D, const int B,
    const int height_in, const int width_in, const int height_out,
    const int width_out, const opmath_t offset_scale, const int remove_center,
    const int d_stride, const int block_thread, const bool softmax,
    const int padded_offset_dim) {
  // Each thread handles d_stride consecutive channels, so D must split evenly.
  assert(D % d_stride == 0);
  if (sizeof(scalar_t) == 2) {
    switch (d_stride) {
    case 1:
      _dcnv4_im2col_cuda<scalar_t, scalar_t, 1>(
          stream, value, p_offset, output, kernel_h, kernel_w, stride_h,
          stride_w, pad_h, pad_w, dilation_h, dilation_w, G, D, B, height_in,
          width_in, height_out, width_out, offset_scale, remove_center,
          block_thread, softmax, padded_offset_dim);
      break;
    case 2:
      _dcnv4_im2col_cuda<scalar_t, uint, 2>(
          stream, value, p_offset, output, kernel_h, kernel_w, stride_h,
          stride_w, pad_h, pad_w, dilation_h, dilation_w, G, D, B, height_in,
          width_in, height_out, width_out, offset_scale, remove_center,
          block_thread, softmax, padded_offset_dim);
      break;
    case 4:
      _dcnv4_im2col_cuda<scalar_t, uint2, 4>(
          stream, value, p_offset, output, kernel_h, kernel_w, stride_h,
          stride_w, pad_h, pad_w, dilation_h, dilation_w, G, D, B, height_in,
          width_in, height_out, width_out, offset_scale, remove_center,
          block_thread, softmax, padded_offset_dim);
      break;
    case 8:
      _dcnv4_im2col_cuda<scalar_t, uint4, 8>(
          stream, value, p_offset, output, kernel_h, kernel_w, stride_h,
          stride_w, pad_h, pad_w, dilation_h, dilation_w, G, D, B, height_in,
          width_in, height_out, width_out, offset_scale, remove_center,
          block_thread, softmax, padded_offset_dim);
      break;
    case 16:
      _dcnv4_im2col_cuda<scalar_t, ulonglong4, 16>(
          stream, value, p_offset, output, kernel_h, kernel_w, stride_h,
          stride_w, pad_h, pad_w, dilation_h, dilation_w, G, D, B, height_in,
          width_in, height_out, width_out, offset_scale, remove_center,
          block_thread, softmax, padded_offset_dim);
      break;
    default:
      // Fail loudly instead of returning with `output` untouched; mirrors the
      // default case of the 4-byte branch below.
      printf("d_stride=%d is not supported for 2-byte element types\n",
             d_stride);
      throw std::invalid_argument("invalid d_stride");
    }
  } else {
    assert(sizeof(scalar_t) == 4);
    switch (d_stride) {
    case 1:
      _dcnv4_im2col_cuda<scalar_t, uint, 1>(
          stream, value, p_offset, output, kernel_h, kernel_w, stride_h,
          stride_w, pad_h, pad_w, dilation_h, dilation_w, G, D, B, height_in,
          width_in, height_out, width_out, offset_scale, remove_center,
          block_thread, softmax, padded_offset_dim);
      break;
    case 2:
      _dcnv4_im2col_cuda<scalar_t, uint2, 2>(
          stream, value, p_offset, output, kernel_h, kernel_w, stride_h,
          stride_w, pad_h, pad_w, dilation_h, dilation_w, G, D, B, height_in,
          width_in, height_out, width_out, offset_scale, remove_center,
          block_thread, softmax, padded_offset_dim);
      break;
    case 4:
      _dcnv4_im2col_cuda<scalar_t, uint4, 4>(
          stream, value, p_offset, output, kernel_h, kernel_w, stride_h,
          stride_w, pad_h, pad_w, dilation_h, dilation_w, G, D, B, height_in,
          width_in, height_out, width_out, offset_scale, remove_center,
          block_thread, softmax, padded_offset_dim);
      break;
    case 8:
      _dcnv4_im2col_cuda<scalar_t, ulonglong4, 8>(
          stream, value, p_offset, output, kernel_h, kernel_w, stride_h,
          stride_w, pad_h, pad_w, dilation_h, dilation_w, G, D, B, height_in,
          width_in, height_out, width_out, offset_scale, remove_center,
          block_thread, softmax, padded_offset_dim);
      break;
    default:
      printf("not supported for d_stride > 8 for fp32");
      throw std::invalid_argument("invalid d_stride");
    }
  }
}
\ No newline at end of file
/*!
**************************************************************************************************
* Deformable DETR
* Copyright (c) 2020 SenseTime. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 [see LICENSE for details]
**************************************************************************************************
* Modified from
*https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
**************************************************************************************************
*/
#include "cuda/flash_deform_im2col_cuda.cuh"
#include "cuda/flash_deform_col2im_cuda.cuh"
#include <vector>
#include <ATen/ATen.h>
#include <ATen/OpMathType.h>
#include <ATen/cuda/CUDAContext.h>
#include <cooperative_groups.h>
#include <cooperative_groups/memcpy_async.h>
#include <cuda.h>
#include <cuda_fp16.h>
#include <cuda_runtime.h>
#include <torch/torch.h>
/*
 * Forward pass of flash deformable attention (CUDA).
 *
 * Reads the problem sizes from the tensor shapes as the code below does:
 *   value             -> size(0)=batch, size(1)=spatial_size,
 *                        size(2)=num_heads, size(3)=num_channels
 *   spatial_shapes    -> size(0)=num_levels, read as int64
 *   level_start_index -> read as int64
 *   sampling_loc_attn -> size(1)=num_query; the per-sample element count is
 *                        taken to be num_query*num_heads*num_levels*K*3
 *
 * The batch is processed in chunks of min(batch, im2col_step) samples; each
 * chunk is handed to flash_deformable_im2col_cuda with pointers advanced to
 * the start of that chunk.
 *
 * Returns a tensor of shape {batch, num_query, num_heads * num_channels} with
 * the options of `value`.
 *
 * Raises (via AT_ASSERTM) when an input is non-contiguous or not on CUDA,
 * when im2col_step is not positive, or when batch is not a multiple of the
 * effective step.
 */
at::Tensor flash_deform_attn_cuda_forward(
    const at::Tensor &value, const at::Tensor &spatial_shapes,
    const at::Tensor &level_start_index, const at::Tensor &sampling_loc_attn,
    const int im2col_step = 64, const int K=8, const int d_stride=8,
    const int block_thread=0) {
  AT_ASSERTM(value.is_contiguous(), "value tensor has to be contiguous");
  AT_ASSERTM(spatial_shapes.is_contiguous(),
             "spatial_shapes tensor has to be contiguous");
  AT_ASSERTM(level_start_index.is_contiguous(),
             "level_start_index tensor has to be contiguous");
  AT_ASSERTM(sampling_loc_attn.is_contiguous(),
             "sampling_loc_attn tensor has to be contiguous");

  // Tensor::is_cuda() replaces the deprecated Tensor::type().is_cuda().
  AT_ASSERTM(value.is_cuda(), "value must be a CUDA tensor");
  AT_ASSERTM(spatial_shapes.is_cuda(),
             "spatial_shapes must be a CUDA tensor");
  AT_ASSERTM(level_start_index.is_cuda(),
             "level_start_index must be a CUDA tensor");
  AT_ASSERTM(sampling_loc_attn.is_cuda(),
             "sampling_loc_attn must be a CUDA tensor");

  const int batch = value.size(0);
  const int spatial_size = value.size(1);
  const int num_heads = value.size(2);
  const int num_channels = value.size(3);

  const int num_levels = spatial_shapes.size(0);
  const int num_query = sampling_loc_attn.size(1);
  const int num_point = K;

  // A non-positive step would make the modulo below divide by zero.
  AT_ASSERTM(im2col_step > 0, "im2col_step(", im2col_step,
             ") must be positive");

  auto output =
      at::zeros({batch, num_query, num_heads, num_channels}, value.options());

  // Empty batch: nothing to launch, and min(batch, im2col_step) would be 0.
  if (batch == 0) {
    return output.view({batch, num_query, num_heads * num_channels});
  }

  const int im2col_step_ = std::min(batch, im2col_step);
  AT_ASSERTM(batch % im2col_step_ == 0, "batch(", batch,
             ") must be divisible by im2col_step(", im2col_step_, ")");

  // Per-sample element counts, kept in 64 bits so that the pointer offsets
  // computed below cannot overflow a 32-bit int on large inputs.
  const int64_t per_value_size =
      static_cast<int64_t>(spatial_size) * num_heads * num_channels;
  const int64_t per_offset_size =
      static_cast<int64_t>(num_query) * num_heads * num_levels * num_point * 3;
  const int64_t per_out_size =
      static_cast<int64_t>(num_query) * num_heads * num_channels;

  const auto stream = at::cuda::getCurrentCUDAStream();
  const int num_chunks = batch / im2col_step_;

  for (int n = 0; n < num_chunks; ++n) {
    // Index of the first sample of this chunk.
    const int64_t first_sample = static_cast<int64_t>(n) * im2col_step_;
    AT_DISPATCH_FLOATING_TYPES_AND2(
        at::ScalarType::Half, at::ScalarType::BFloat16, value.scalar_type(),
        "flash_deform_attn_forward_cuda", ([&] {
          flash_deformable_im2col_cuda(
              stream,
              value.data_ptr<scalar_t>() + first_sample * per_value_size,
              spatial_shapes.data_ptr<int64_t>(),
              level_start_index.data_ptr<int64_t>(),
              sampling_loc_attn.data_ptr<scalar_t>() +
                  first_sample * per_offset_size,
              output.data_ptr<scalar_t>() + first_sample * per_out_size,
              im2col_step_, spatial_size, num_heads, num_channels, num_levels,
              num_query, num_point, d_stride, block_thread, true);
        }));
  }

  // Merge the head and channel dimensions for the caller.
  output = output.view({batch, num_query, num_heads * num_channels});
  return output;
}
/*
 * Backward pass of flash deformable attention (CUDA).
 *
 * Problem sizes are read from the tensor shapes exactly as in the forward
 * binding (value -> batch/spatial_size/num_heads/num_channels,
 * spatial_shapes.size(0) -> num_levels, sampling_loc_attn.size(1) ->
 * num_query). The batch is processed in chunks of min(batch, im2col_step)
 * samples through flash_deformable_col2im_cuda.
 *
 * Gradient buffers are passed to the kernel as opmath_t pointers. For
 * reduced-precision inputs they are therefore allocated in float and cast back
 * to the dtype of `value` before returning.
 *
 * Returns {grad_value, grad_sampling_loc_attn}, shaped like `value` and
 * `sampling_loc_attn` respectively and in the dtype of `value`.
 *
 * Raises (via AT_ASSERTM) when an input is non-contiguous or not on CUDA,
 * when im2col_step is not positive, or when batch is not a multiple of the
 * effective step.
 */
std::vector<at::Tensor>
flash_deform_attn_cuda_backward(
    const at::Tensor &value, const at::Tensor &spatial_shapes,
    const at::Tensor &level_start_index, const at::Tensor &sampling_loc_attn,
    const at::Tensor &grad_output, const int im2col_step = 64, const int K=8,
    const int d_stride=2, const int block_thread=0) {
  AT_ASSERTM(value.is_contiguous(), "value tensor has to be contiguous");
  AT_ASSERTM(spatial_shapes.is_contiguous(),
             "spatial_shapes tensor has to be contiguous");
  AT_ASSERTM(level_start_index.is_contiguous(),
             "level_start_index tensor has to be contiguous");
  AT_ASSERTM(sampling_loc_attn.is_contiguous(),
             "sampling_loc_attn tensor has to be contiguous");
  AT_ASSERTM(grad_output.is_contiguous(),
             "grad_output tensor has to be contiguous");

  // Tensor::is_cuda() replaces the deprecated Tensor::type().is_cuda().
  AT_ASSERTM(value.is_cuda(), "value must be a CUDA tensor");
  AT_ASSERTM(spatial_shapes.is_cuda(),
             "spatial_shapes must be a CUDA tensor");
  AT_ASSERTM(level_start_index.is_cuda(),
             "level_start_index must be a CUDA tensor");
  AT_ASSERTM(sampling_loc_attn.is_cuda(),
             "sampling_loc_attn must be a CUDA tensor");
  AT_ASSERTM(grad_output.is_cuda(),
             "grad_output must be a CUDA tensor");

  const int batch = value.size(0);
  const int spatial_size = value.size(1);
  const int num_heads = value.size(2);
  const int num_channels = value.size(3);

  const int num_levels = spatial_shapes.size(0);
  const int num_query = sampling_loc_attn.size(1);
  const int num_point = K;

  // A non-positive step would make the modulo below divide by zero.
  AT_ASSERTM(im2col_step > 0, "im2col_step(", im2col_step,
             ") must be positive");

  // Accumulate in float for 16-bit inputs. The dispatch below accepts
  // BFloat16 as well as Half, so both need the promoted buffers.
  // NOTE(review): assumes opmath_t is float for BFloat16 just as it is for
  // Half -- confirm against the opmath_t definition in the included headers.
  const at::ScalarType value_dtype = value.scalar_type();
  const bool reduced_precision =
      value_dtype == at::kHalf || value_dtype == at::kBFloat16;
  const at::ScalarType grad_dtype = reduced_precision ? at::kFloat : value_dtype;

  auto grad_input = at::zeros_like(value, grad_dtype);
  auto grad_offset = at::zeros_like(sampling_loc_attn, grad_dtype);

  // With an empty batch min(batch, im2col_step) is 0; skip the chunk loop
  // entirely instead of evaluating batch % 0.
  if (batch > 0) {
    const int im2col_step_ = std::min(batch, im2col_step);
    AT_ASSERTM(batch % im2col_step_ == 0, "batch(", batch,
               ") must be divisible by im2col_step(", im2col_step_, ")");

    // Per-sample element counts, kept in 64 bits so that the pointer offsets
    // computed below cannot overflow a 32-bit int on large inputs.
    const int64_t per_value_size =
        static_cast<int64_t>(spatial_size) * num_heads * num_channels;
    const int64_t per_offset_size =
        static_cast<int64_t>(num_query) * num_heads * num_levels * num_point *
        3;
    const int64_t per_out_size =
        static_cast<int64_t>(num_query) * num_heads * num_channels;

    const auto stream = at::cuda::getCurrentCUDAStream();
    const int num_chunks = batch / im2col_step_;

    for (int n = 0; n < num_chunks; ++n) {
      // Index of the first sample of this chunk.
      const int64_t first_sample = static_cast<int64_t>(n) * im2col_step_;
      AT_DISPATCH_FLOATING_TYPES_AND2(
          at::ScalarType::Half, at::ScalarType::BFloat16, value.scalar_type(),
          "flash_deform_attn_backward_cuda", ([&] {
            flash_deformable_col2im_cuda(
                stream,
                value.data_ptr<scalar_t>() + first_sample * per_value_size,
                spatial_shapes.data_ptr<int64_t>(),
                level_start_index.data_ptr<int64_t>(),
                sampling_loc_attn.data_ptr<scalar_t>() +
                    first_sample * per_offset_size,
                grad_output.data_ptr<scalar_t>() + first_sample * per_out_size,
                im2col_step_, spatial_size, num_heads, num_channels, num_levels,
                num_query, num_point,
                grad_input.data_ptr<opmath_t>() + first_sample * per_value_size,
                grad_offset.data_ptr<opmath_t>() +
                    first_sample * per_offset_size,
                d_stride, block_thread
            );
          }));
    }
  }

  // Hand gradients back in the dtype of `value`.
  if (reduced_precision) {
    return {grad_input.to(value_dtype), grad_offset.to(value_dtype)};
  }
  return {grad_input, grad_offset};
}
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment