Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
zk
GroundingDINO-DCU-Optimized
Commits
34e4011b
Commit
34e4011b
authored
Apr 14, 2026
by
zk
Browse files
首次提交
parents
Pipeline
#3503
failed with stages
in 0 seconds
Changes
150
Pipelines
1
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
7803 additions
and
0 deletions
+7803
-0
groundingdino/models/GroundingDINO/csrc/MsDeformAttn/ms_deform_attn_cuda.h
...els/GroundingDINO/csrc/MsDeformAttn/ms_deform_attn_cuda.h
+34
-0
groundingdino/models/GroundingDINO/csrc/MsDeformAttn/ms_deform_attn_cuda.hip
...s/GroundingDINO/csrc/MsDeformAttn/ms_deform_attn_cuda.hip
+189
-0
groundingdino/models/GroundingDINO/csrc/MsDeformAttn/ms_deform_im2col_cuda.cuh
...GroundingDINO/csrc/MsDeformAttn/ms_deform_im2col_cuda.cuh
+1328
-0
groundingdino/models/GroundingDINO/csrc/MsDeformAttn/ms_deform_im2col_cuda_hip.cuh
...ndingDINO/csrc/MsDeformAttn/ms_deform_im2col_cuda_hip.cuh
+1331
-0
groundingdino/models/GroundingDINO/csrc/cuda_version.cu
groundingdino/models/GroundingDINO/csrc/cuda_version.cu
+7
-0
groundingdino/models/GroundingDINO/csrc/hip_version.hip
groundingdino/models/GroundingDINO/csrc/hip_version.hip
+9
-0
groundingdino/models/GroundingDINO/csrc/vision.cpp
groundingdino/models/GroundingDINO/csrc/vision.cpp
+59
-0
groundingdino/models/GroundingDINO/fuse_modules.py
groundingdino/models/GroundingDINO/fuse_modules.py
+297
-0
groundingdino/models/GroundingDINO/groundingdino.py
groundingdino/models/GroundingDINO/groundingdino.py
+434
-0
groundingdino/models/GroundingDINO/groundingdino_onnx.py
groundingdino/models/GroundingDINO/groundingdino_onnx.py
+434
-0
groundingdino/models/GroundingDINO/groundingdino_torch.py
groundingdino/models/GroundingDINO/groundingdino_torch.py
+434
-0
groundingdino/models/GroundingDINO/ms_deform_attn copy.py
groundingdino/models/GroundingDINO/ms_deform_attn copy.py
+413
-0
groundingdino/models/GroundingDINO/ms_deform_attn.py
groundingdino/models/GroundingDINO/ms_deform_attn.py
+440
-0
groundingdino/models/GroundingDINO/transformer copy.py
groundingdino/models/GroundingDINO/transformer copy.py
+959
-0
groundingdino/models/GroundingDINO/transformer.py
groundingdino/models/GroundingDINO/transformer.py
+959
-0
groundingdino/models/GroundingDINO/transformer_vanilla.py
groundingdino/models/GroundingDINO/transformer_vanilla.py
+123
-0
groundingdino/models/GroundingDINO/utils.py
groundingdino/models/GroundingDINO/utils.py
+268
-0
groundingdino/models/__init__.py
groundingdino/models/__init__.py
+18
-0
groundingdino/models/registry.py
groundingdino/models/registry.py
+66
-0
groundingdino/util/__init__.py
groundingdino/util/__init__.py
+1
-0
No files found.
groundingdino/models/GroundingDINO/csrc/MsDeformAttn/ms_deform_attn_cuda.h
0 → 100644
View file @
34e4011b
/*!
**************************************************************************************************
* Deformable DETR
* Copyright (c) 2020 SenseTime. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 [see LICENSE for details]
**************************************************************************************************
* Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
**************************************************************************************************
*/
#pragma once
#include <torch/extension.h>
namespace
groundingdino
{
at
::
Tensor
ms_deform_attn_cuda_forward
(
const
at
::
Tensor
&
value
,
const
at
::
Tensor
&
spatial_shapes
,
const
at
::
Tensor
&
level_start_index
,
const
at
::
Tensor
&
sampling_loc
,
const
at
::
Tensor
&
attn_weight
,
const
int
im2col_step
);
std
::
vector
<
at
::
Tensor
>
ms_deform_attn_cuda_backward
(
const
at
::
Tensor
&
value
,
const
at
::
Tensor
&
spatial_shapes
,
const
at
::
Tensor
&
level_start_index
,
const
at
::
Tensor
&
sampling_loc
,
const
at
::
Tensor
&
attn_weight
,
const
at
::
Tensor
&
grad_output
,
const
int
im2col_step
);
}
// namespace groundingdino
\ No newline at end of file
groundingdino/models/GroundingDINO/csrc/MsDeformAttn/ms_deform_attn_cuda.hip
0 → 100644
View file @
34e4011b
// !!! This is a file automatically generated by hipify!!!
#include <ATen/dtk_macros.h>
/*!
**************************************************************************************************
* Deformable DETR
* Copyright (c) 2020 SenseTime. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 [see LICENSE for details]
**************************************************************************************************
* Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
**************************************************************************************************
*/
#include <vector>
#include "ms_deform_im2col_cuda_hip.cuh"
#include <ATen/ATen.h>
#include <ATen/hip/HIPContext.h>
#include <hip/hip_runtime.h>
#include <hip/hip_runtime.h>
namespace groundingdino {
at::Tensor ms_deform_attn_cuda_forward(
const at::Tensor &value,
const at::Tensor &spatial_shapes,
const at::Tensor &level_start_index,
const at::Tensor &sampling_loc,
const at::Tensor &attn_weight,
const int im2col_step)
{
AT_ASSERTM(value.is_contiguous(), "value tensor has to be contiguous");
AT_ASSERTM(spatial_shapes.is_contiguous(), "spatial_shapes tensor has to be contiguous");
AT_ASSERTM(level_start_index.is_contiguous(), "level_start_index tensor has to be contiguous");
AT_ASSERTM(sampling_loc.is_contiguous(), "sampling_loc tensor has to be contiguous");
AT_ASSERTM(attn_weight.is_contiguous(), "attn_weight tensor has to be contiguous");
AT_ASSERTM(value.type().is_cuda(), "value must be a CUDA tensor");
AT_ASSERTM(spatial_shapes.type().is_cuda(), "spatial_shapes must be a CUDA tensor");
AT_ASSERTM(level_start_index.type().is_cuda(), "level_start_index must be a CUDA tensor");
AT_ASSERTM(sampling_loc.type().is_cuda(), "sampling_loc must be a CUDA tensor");
AT_ASSERTM(attn_weight.type().is_cuda(), "attn_weight must be a CUDA tensor");
const int batch = value.size(0);
const int spatial_size = value.size(1);
const int num_heads = value.size(2);
const int channels = value.size(3);
const int num_levels = spatial_shapes.size(0);
const int num_query = sampling_loc.size(1);
const int num_point = sampling_loc.size(4);
const int im2col_step_ = ::min(batch, im2col_step);
AT_ASSERTM(batch % im2col_step_ == 0, "batch(%d) must divide im2col_step(%d)", batch, im2col_step_);
auto output = at::zeros({batch, num_query, num_heads, channels}, value.options());
const int batch_n = im2col_step_;
auto output_n = output.view({batch/im2col_step_, batch_n, num_query, num_heads, channels});
auto per_value_size = spatial_size * num_heads * channels;
auto per_sample_loc_size = num_query * num_heads * num_levels * num_point * 2;
auto per_attn_weight_size = num_query * num_heads * num_levels * num_point;
for (int n = 0; n < batch/im2col_step_; ++n)
{
auto columns = output_n.select(0, n);
AT_DISPATCH_FLOATING_TYPES(value.type(), "ms_deform_attn_forward_cuda", ([&] {
ms_deformable_im2col_cuda(at::hip::getCurrentHIPStreamMasqueradingAsCUDA(),
value.data<scalar_t>() + n * im2col_step_ * per_value_size,
spatial_shapes.data<int64_t>(),
level_start_index.data<int64_t>(),
sampling_loc.data<scalar_t>() + n * im2col_step_ * per_sample_loc_size,
attn_weight.data<scalar_t>() + n * im2col_step_ * per_attn_weight_size,
batch_n, spatial_size, num_heads, channels, num_levels, num_query, num_point,
columns.data<scalar_t>());
}));
}
output = output.view({batch, num_query, num_heads*channels});
return output;
}
// at::Tensor ms_deform_attn_forward_wrapper(
// const at::Tensor &value,
// const at::Tensor &spatial_shapes,
// const at::Tensor &level_start_index,
// const at::Tensor &sampling_loc,
// const at::Tensor &attn_weight,
// int64_t im2col_step // ✅ 注意这里
// )
// {
// return groundingdino::ms_deform_attn_cuda_forward(
// value,
// spatial_shapes,
// level_start_index,
// sampling_loc,
// attn_weight,
// im2col_step
// );
// }
std::vector<at::Tensor> ms_deform_attn_cuda_backward(
const at::Tensor &value,
const at::Tensor &spatial_shapes,
const at::Tensor &level_start_index,
const at::Tensor &sampling_loc,
const at::Tensor &attn_weight,
const at::Tensor &grad_output,
const int im2col_step)
{
AT_ASSERTM(value.is_contiguous(), "value tensor has to be contiguous");
AT_ASSERTM(spatial_shapes.is_contiguous(), "spatial_shapes tensor has to be contiguous");
AT_ASSERTM(level_start_index.is_contiguous(), "level_start_index tensor has to be contiguous");
AT_ASSERTM(sampling_loc.is_contiguous(), "sampling_loc tensor has to be contiguous");
AT_ASSERTM(attn_weight.is_contiguous(), "attn_weight tensor has to be contiguous");
AT_ASSERTM(grad_output.is_contiguous(), "grad_output tensor has to be contiguous");
AT_ASSERTM(value.type().is_cuda(), "value must be a CUDA tensor");
AT_ASSERTM(spatial_shapes.type().is_cuda(), "spatial_shapes must be a CUDA tensor");
AT_ASSERTM(level_start_index.type().is_cuda(), "level_start_index must be a CUDA tensor");
AT_ASSERTM(sampling_loc.type().is_cuda(), "sampling_loc must be a CUDA tensor");
AT_ASSERTM(attn_weight.type().is_cuda(), "attn_weight must be a CUDA tensor");
AT_ASSERTM(grad_output.type().is_cuda(), "grad_output must be a CUDA tensor");
const int batch = value.size(0);
const int spatial_size = value.size(1);
const int num_heads = value.size(2);
const int channels = value.size(3);
const int num_levels = spatial_shapes.size(0);
const int num_query = sampling_loc.size(1);
const int num_point = sampling_loc.size(4);
const int im2col_step_ = ::min(batch, im2col_step);
AT_ASSERTM(batch % im2col_step_ == 0, "batch(%d) must divide im2col_step(%d)", batch, im2col_step_);
auto grad_value = at::zeros_like(value);
auto grad_sampling_loc = at::zeros_like(sampling_loc);
auto grad_attn_weight = at::zeros_like(attn_weight);
const int batch_n = im2col_step_;
auto per_value_size = spatial_size * num_heads * channels;
auto per_sample_loc_size = num_query * num_heads * num_levels * num_point * 2;
auto per_attn_weight_size = num_query * num_heads * num_levels * num_point;
auto grad_output_n = grad_output.view({batch/im2col_step_, batch_n, num_query, num_heads, channels});
for (int n = 0; n < batch/im2col_step_; ++n)
{
auto grad_output_g = grad_output_n.select(0, n);
AT_DISPATCH_FLOATING_TYPES(value.type(), "ms_deform_attn_backward_cuda", ([&] {
ms_deformable_col2im_cuda(at::hip::getCurrentHIPStreamMasqueradingAsCUDA(),
grad_output_g.data<scalar_t>(),
value.data<scalar_t>() + n * im2col_step_ * per_value_size,
spatial_shapes.data<int64_t>(),
level_start_index.data<int64_t>(),
sampling_loc.data<scalar_t>() + n * im2col_step_ * per_sample_loc_size,
attn_weight.data<scalar_t>() + n * im2col_step_ * per_attn_weight_size,
batch_n, spatial_size, num_heads, channels, num_levels, num_query, num_point,
grad_value.data<scalar_t>() + n * im2col_step_ * per_value_size,
grad_sampling_loc.data<scalar_t>() + n * im2col_step_ * per_sample_loc_size,
grad_attn_weight.data<scalar_t>() + n * im2col_step_ * per_attn_weight_size);
}));
}
return {
grad_value, grad_sampling_loc, grad_attn_weight
};
}
} // namespace groundingdino
// #include <torch/library.h>
// // 注册 schema
// TORCH_LIBRARY(my_ops, m) {
// m.def("ms_deform_attn(Tensor value, Tensor spatial_shapes, Tensor level_start_index, Tensor sampling_loc, Tensor attn_weight, int im2col_step) -> Tensor");
// }
// // CUDA实现
// TORCH_LIBRARY_IMPL(my_ops, CUDA, m) {
// m.impl("ms_deform_attn", groundingdino::ms_deform_attn_forward_wrapper);
// }
\ No newline at end of file
groundingdino/models/GroundingDINO/csrc/MsDeformAttn/ms_deform_im2col_cuda.cuh
0 → 100644
View file @
34e4011b
/*!
**************************************************************************
* Deformable DETR
* Copyright (c) 2020 SenseTime. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 [see LICENSE for details]
**************************************************************************
* Modified from DCN (https://github.com/msracver/Deformable-ConvNets)
* Copyright (c) 2018 Microsoft
**************************************************************************
*/
#include <cstdio>
#include <algorithm>
#include <cstring>
#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h>
#include <THC/THCAtomics.cuh>
#define CUDA_KERNEL_LOOP(i, n) \
for (int i = blockIdx.x * blockDim.x + threadIdx.x; \
i < (n); \
i += blockDim.x * gridDim.x)
const
int
CUDA_NUM_THREADS
=
1024
;
inline
int
GET_BLOCKS
(
const
int
N
,
const
int
num_threads
)
{
return
(
N
+
num_threads
-
1
)
/
num_threads
;
}
template
<
typename
scalar_t
>
__device__
scalar_t
ms_deform_attn_im2col_bilinear
(
const
scalar_t
*
&
bottom_data
,
const
int
&
height
,
const
int
&
width
,
const
int
&
nheads
,
const
int
&
channels
,
const
scalar_t
&
h
,
const
scalar_t
&
w
,
const
int
&
m
,
const
int
&
c
)
{
const
int
h_low
=
floor
(
h
);
const
int
w_low
=
floor
(
w
);
const
int
h_high
=
h_low
+
1
;
const
int
w_high
=
w_low
+
1
;
const
scalar_t
lh
=
h
-
h_low
;
const
scalar_t
lw
=
w
-
w_low
;
const
scalar_t
hh
=
1
-
lh
,
hw
=
1
-
lw
;
const
int
w_stride
=
nheads
*
channels
;
const
int
h_stride
=
width
*
w_stride
;
const
int
h_low_ptr_offset
=
h_low
*
h_stride
;
const
int
h_high_ptr_offset
=
h_low_ptr_offset
+
h_stride
;
const
int
w_low_ptr_offset
=
w_low
*
w_stride
;
const
int
w_high_ptr_offset
=
w_low_ptr_offset
+
w_stride
;
const
int
base_ptr
=
m
*
channels
+
c
;
scalar_t
v1
=
0
;
if
(
h_low
>=
0
&&
w_low
>=
0
)
{
const
int
ptr1
=
h_low_ptr_offset
+
w_low_ptr_offset
+
base_ptr
;
v1
=
bottom_data
[
ptr1
];
}
scalar_t
v2
=
0
;
if
(
h_low
>=
0
&&
w_high
<=
width
-
1
)
{
const
int
ptr2
=
h_low_ptr_offset
+
w_high_ptr_offset
+
base_ptr
;
v2
=
bottom_data
[
ptr2
];
}
scalar_t
v3
=
0
;
if
(
h_high
<=
height
-
1
&&
w_low
>=
0
)
{
const
int
ptr3
=
h_high_ptr_offset
+
w_low_ptr_offset
+
base_ptr
;
v3
=
bottom_data
[
ptr3
];
}
scalar_t
v4
=
0
;
if
(
h_high
<=
height
-
1
&&
w_high
<=
width
-
1
)
{
const
int
ptr4
=
h_high_ptr_offset
+
w_high_ptr_offset
+
base_ptr
;
v4
=
bottom_data
[
ptr4
];
}
const
scalar_t
w1
=
hh
*
hw
,
w2
=
hh
*
lw
,
w3
=
lh
*
hw
,
w4
=
lh
*
lw
;
const
scalar_t
val
=
(
w1
*
v1
+
w2
*
v2
+
w3
*
v3
+
w4
*
v4
);
return
val
;
}
template
<
typename
scalar_t
>
__device__
void
ms_deform_attn_col2im_bilinear
(
const
scalar_t
*
&
bottom_data
,
const
int
&
height
,
const
int
&
width
,
const
int
&
nheads
,
const
int
&
channels
,
const
scalar_t
&
h
,
const
scalar_t
&
w
,
const
int
&
m
,
const
int
&
c
,
const
scalar_t
&
top_grad
,
const
scalar_t
&
attn_weight
,
scalar_t
*
&
grad_value
,
scalar_t
*
grad_sampling_loc
,
scalar_t
*
grad_attn_weight
)
{
const
int
h_low
=
floor
(
h
);
const
int
w_low
=
floor
(
w
);
const
int
h_high
=
h_low
+
1
;
const
int
w_high
=
w_low
+
1
;
const
scalar_t
lh
=
h
-
h_low
;
const
scalar_t
lw
=
w
-
w_low
;
const
scalar_t
hh
=
1
-
lh
,
hw
=
1
-
lw
;
const
int
w_stride
=
nheads
*
channels
;
const
int
h_stride
=
width
*
w_stride
;
const
int
h_low_ptr_offset
=
h_low
*
h_stride
;
const
int
h_high_ptr_offset
=
h_low_ptr_offset
+
h_stride
;
const
int
w_low_ptr_offset
=
w_low
*
w_stride
;
const
int
w_high_ptr_offset
=
w_low_ptr_offset
+
w_stride
;
const
int
base_ptr
=
m
*
channels
+
c
;
const
scalar_t
w1
=
hh
*
hw
,
w2
=
hh
*
lw
,
w3
=
lh
*
hw
,
w4
=
lh
*
lw
;
const
scalar_t
top_grad_value
=
top_grad
*
attn_weight
;
scalar_t
grad_h_weight
=
0
,
grad_w_weight
=
0
;
scalar_t
v1
=
0
;
if
(
h_low
>=
0
&&
w_low
>=
0
)
{
const
int
ptr1
=
h_low_ptr_offset
+
w_low_ptr_offset
+
base_ptr
;
v1
=
bottom_data
[
ptr1
];
grad_h_weight
-=
hw
*
v1
;
grad_w_weight
-=
hh
*
v1
;
atomicAdd
(
grad_value
+
ptr1
,
w1
*
top_grad_value
);
}
scalar_t
v2
=
0
;
if
(
h_low
>=
0
&&
w_high
<=
width
-
1
)
{
const
int
ptr2
=
h_low_ptr_offset
+
w_high_ptr_offset
+
base_ptr
;
v2
=
bottom_data
[
ptr2
];
grad_h_weight
-=
lw
*
v2
;
grad_w_weight
+=
hh
*
v2
;
atomicAdd
(
grad_value
+
ptr2
,
w2
*
top_grad_value
);
}
scalar_t
v3
=
0
;
if
(
h_high
<=
height
-
1
&&
w_low
>=
0
)
{
const
int
ptr3
=
h_high_ptr_offset
+
w_low_ptr_offset
+
base_ptr
;
v3
=
bottom_data
[
ptr3
];
grad_h_weight
+=
hw
*
v3
;
grad_w_weight
-=
lh
*
v3
;
atomicAdd
(
grad_value
+
ptr3
,
w3
*
top_grad_value
);
}
scalar_t
v4
=
0
;
if
(
h_high
<=
height
-
1
&&
w_high
<=
width
-
1
)
{
const
int
ptr4
=
h_high_ptr_offset
+
w_high_ptr_offset
+
base_ptr
;
v4
=
bottom_data
[
ptr4
];
grad_h_weight
+=
lw
*
v4
;
grad_w_weight
+=
lh
*
v4
;
atomicAdd
(
grad_value
+
ptr4
,
w4
*
top_grad_value
);
}
const
scalar_t
val
=
(
w1
*
v1
+
w2
*
v2
+
w3
*
v3
+
w4
*
v4
);
*
grad_attn_weight
=
top_grad
*
val
;
*
grad_sampling_loc
=
width
*
grad_w_weight
*
top_grad_value
;
*
(
grad_sampling_loc
+
1
)
=
height
*
grad_h_weight
*
top_grad_value
;
}
template
<
typename
scalar_t
>
__device__
void
ms_deform_attn_col2im_bilinear_gm
(
const
scalar_t
*
&
bottom_data
,
const
int
&
height
,
const
int
&
width
,
const
int
&
nheads
,
const
int
&
channels
,
const
scalar_t
&
h
,
const
scalar_t
&
w
,
const
int
&
m
,
const
int
&
c
,
const
scalar_t
&
top_grad
,
const
scalar_t
&
attn_weight
,
scalar_t
*
&
grad_value
,
scalar_t
*
grad_sampling_loc
,
scalar_t
*
grad_attn_weight
)
{
const
int
h_low
=
floor
(
h
);
const
int
w_low
=
floor
(
w
);
const
int
h_high
=
h_low
+
1
;
const
int
w_high
=
w_low
+
1
;
const
scalar_t
lh
=
h
-
h_low
;
const
scalar_t
lw
=
w
-
w_low
;
const
scalar_t
hh
=
1
-
lh
,
hw
=
1
-
lw
;
const
int
w_stride
=
nheads
*
channels
;
const
int
h_stride
=
width
*
w_stride
;
const
int
h_low_ptr_offset
=
h_low
*
h_stride
;
const
int
h_high_ptr_offset
=
h_low_ptr_offset
+
h_stride
;
const
int
w_low_ptr_offset
=
w_low
*
w_stride
;
const
int
w_high_ptr_offset
=
w_low_ptr_offset
+
w_stride
;
const
int
base_ptr
=
m
*
channels
+
c
;
const
scalar_t
w1
=
hh
*
hw
,
w2
=
hh
*
lw
,
w3
=
lh
*
hw
,
w4
=
lh
*
lw
;
const
scalar_t
top_grad_value
=
top_grad
*
attn_weight
;
scalar_t
grad_h_weight
=
0
,
grad_w_weight
=
0
;
scalar_t
v1
=
0
;
if
(
h_low
>=
0
&&
w_low
>=
0
)
{
const
int
ptr1
=
h_low_ptr_offset
+
w_low_ptr_offset
+
base_ptr
;
v1
=
bottom_data
[
ptr1
];
grad_h_weight
-=
hw
*
v1
;
grad_w_weight
-=
hh
*
v1
;
atomicAdd
(
grad_value
+
ptr1
,
w1
*
top_grad_value
);
}
scalar_t
v2
=
0
;
if
(
h_low
>=
0
&&
w_high
<=
width
-
1
)
{
const
int
ptr2
=
h_low_ptr_offset
+
w_high_ptr_offset
+
base_ptr
;
v2
=
bottom_data
[
ptr2
];
grad_h_weight
-=
lw
*
v2
;
grad_w_weight
+=
hh
*
v2
;
atomicAdd
(
grad_value
+
ptr2
,
w2
*
top_grad_value
);
}
scalar_t
v3
=
0
;
if
(
h_high
<=
height
-
1
&&
w_low
>=
0
)
{
const
int
ptr3
=
h_high_ptr_offset
+
w_low_ptr_offset
+
base_ptr
;
v3
=
bottom_data
[
ptr3
];
grad_h_weight
+=
hw
*
v3
;
grad_w_weight
-=
lh
*
v3
;
atomicAdd
(
grad_value
+
ptr3
,
w3
*
top_grad_value
);
}
scalar_t
v4
=
0
;
if
(
h_high
<=
height
-
1
&&
w_high
<=
width
-
1
)
{
const
int
ptr4
=
h_high_ptr_offset
+
w_high_ptr_offset
+
base_ptr
;
v4
=
bottom_data
[
ptr4
];
grad_h_weight
+=
lw
*
v4
;
grad_w_weight
+=
lh
*
v4
;
atomicAdd
(
grad_value
+
ptr4
,
w4
*
top_grad_value
);
}
const
scalar_t
val
=
(
w1
*
v1
+
w2
*
v2
+
w3
*
v3
+
w4
*
v4
);
atomicAdd
(
grad_attn_weight
,
top_grad
*
val
);
atomicAdd
(
grad_sampling_loc
,
width
*
grad_w_weight
*
top_grad_value
);
atomicAdd
(
grad_sampling_loc
+
1
,
height
*
grad_h_weight
*
top_grad_value
);
}
template
<
typename
scalar_t
>
__global__
void
ms_deformable_im2col_gpu_kernel
(
const
int
n
,
const
scalar_t
*
data_value
,
const
int64_t
*
data_spatial_shapes
,
const
int64_t
*
data_level_start_index
,
const
scalar_t
*
data_sampling_loc
,
const
scalar_t
*
data_attn_weight
,
const
int
batch_size
,
const
int
spatial_size
,
const
int
num_heads
,
const
int
channels
,
const
int
num_levels
,
const
int
num_query
,
const
int
num_point
,
scalar_t
*
data_col
)
{
CUDA_KERNEL_LOOP
(
index
,
n
)
{
int
_temp
=
index
;
const
int
c_col
=
_temp
%
channels
;
_temp
/=
channels
;
const
int
sampling_index
=
_temp
;
const
int
m_col
=
_temp
%
num_heads
;
_temp
/=
num_heads
;
const
int
q_col
=
_temp
%
num_query
;
_temp
/=
num_query
;
const
int
b_col
=
_temp
;
scalar_t
*
data_col_ptr
=
data_col
+
index
;
int
data_weight_ptr
=
sampling_index
*
num_levels
*
num_point
;
int
data_loc_w_ptr
=
data_weight_ptr
<<
1
;
const
int
qid_stride
=
num_heads
*
channels
;
const
int
data_value_ptr_init_offset
=
b_col
*
spatial_size
*
qid_stride
;
scalar_t
col
=
0
;
for
(
int
l_col
=
0
;
l_col
<
num_levels
;
++
l_col
)
{
const
int
level_start_id
=
data_level_start_index
[
l_col
];
const
int
spatial_h_ptr
=
l_col
<<
1
;
const
int
spatial_h
=
data_spatial_shapes
[
spatial_h_ptr
];
const
int
spatial_w
=
data_spatial_shapes
[
spatial_h_ptr
+
1
];
const
scalar_t
*
data_value_ptr
=
data_value
+
(
data_value_ptr_init_offset
+
level_start_id
*
qid_stride
);
for
(
int
p_col
=
0
;
p_col
<
num_point
;
++
p_col
)
{
const
scalar_t
loc_w
=
data_sampling_loc
[
data_loc_w_ptr
];
const
scalar_t
loc_h
=
data_sampling_loc
[
data_loc_w_ptr
+
1
];
const
scalar_t
weight
=
data_attn_weight
[
data_weight_ptr
];
const
scalar_t
h_im
=
loc_h
*
spatial_h
-
0.5
;
const
scalar_t
w_im
=
loc_w
*
spatial_w
-
0.5
;
if
(
h_im
>
-
1
&&
w_im
>
-
1
&&
h_im
<
spatial_h
&&
w_im
<
spatial_w
)
{
col
+=
ms_deform_attn_im2col_bilinear
(
data_value_ptr
,
spatial_h
,
spatial_w
,
num_heads
,
channels
,
h_im
,
w_im
,
m_col
,
c_col
)
*
weight
;
}
data_weight_ptr
+=
1
;
data_loc_w_ptr
+=
2
;
}
}
*
data_col_ptr
=
col
;
}
}
template
<
typename
scalar_t
,
unsigned
int
blockSize
>
__global__
void
ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v1
(
const
int
n
,
const
scalar_t
*
grad_col
,
const
scalar_t
*
data_value
,
const
int64_t
*
data_spatial_shapes
,
const
int64_t
*
data_level_start_index
,
const
scalar_t
*
data_sampling_loc
,
const
scalar_t
*
data_attn_weight
,
const
int
batch_size
,
const
int
spatial_size
,
const
int
num_heads
,
const
int
channels
,
const
int
num_levels
,
const
int
num_query
,
const
int
num_point
,
scalar_t
*
grad_value
,
scalar_t
*
grad_sampling_loc
,
scalar_t
*
grad_attn_weight
)
{
CUDA_KERNEL_LOOP
(
index
,
n
)
{
__shared__
scalar_t
cache_grad_sampling_loc
[
blockSize
*
2
];
__shared__
scalar_t
cache_grad_attn_weight
[
blockSize
];
unsigned
int
tid
=
threadIdx
.
x
;
int
_temp
=
index
;
const
int
c_col
=
_temp
%
channels
;
_temp
/=
channels
;
const
int
sampling_index
=
_temp
;
const
int
m_col
=
_temp
%
num_heads
;
_temp
/=
num_heads
;
const
int
q_col
=
_temp
%
num_query
;
_temp
/=
num_query
;
const
int
b_col
=
_temp
;
const
scalar_t
top_grad
=
grad_col
[
index
];
int
data_weight_ptr
=
sampling_index
*
num_levels
*
num_point
;
int
data_loc_w_ptr
=
data_weight_ptr
<<
1
;
const
int
grad_sampling_ptr
=
data_weight_ptr
;
grad_sampling_loc
+=
grad_sampling_ptr
<<
1
;
grad_attn_weight
+=
grad_sampling_ptr
;
const
int
grad_weight_stride
=
1
;
const
int
grad_loc_stride
=
2
;
const
int
qid_stride
=
num_heads
*
channels
;
const
int
data_value_ptr_init_offset
=
b_col
*
spatial_size
*
qid_stride
;
for
(
int
l_col
=
0
;
l_col
<
num_levels
;
++
l_col
)
{
const
int
level_start_id
=
data_level_start_index
[
l_col
];
const
int
spatial_h_ptr
=
l_col
<<
1
;
const
int
spatial_h
=
data_spatial_shapes
[
spatial_h_ptr
];
const
int
spatial_w
=
data_spatial_shapes
[
spatial_h_ptr
+
1
];
const
int
value_ptr_offset
=
data_value_ptr_init_offset
+
level_start_id
*
qid_stride
;
const
scalar_t
*
data_value_ptr
=
data_value
+
value_ptr_offset
;
scalar_t
*
grad_value_ptr
=
grad_value
+
value_ptr_offset
;
for
(
int
p_col
=
0
;
p_col
<
num_point
;
++
p_col
)
{
const
scalar_t
loc_w
=
data_sampling_loc
[
data_loc_w_ptr
];
const
scalar_t
loc_h
=
data_sampling_loc
[
data_loc_w_ptr
+
1
];
const
scalar_t
weight
=
data_attn_weight
[
data_weight_ptr
];
const
scalar_t
h_im
=
loc_h
*
spatial_h
-
0.5
;
const
scalar_t
w_im
=
loc_w
*
spatial_w
-
0.5
;
*
(
cache_grad_sampling_loc
+
(
threadIdx
.
x
<<
1
))
=
0
;
*
(
cache_grad_sampling_loc
+
((
threadIdx
.
x
<<
1
)
+
1
))
=
0
;
*
(
cache_grad_attn_weight
+
threadIdx
.
x
)
=
0
;
if
(
h_im
>
-
1
&&
w_im
>
-
1
&&
h_im
<
spatial_h
&&
w_im
<
spatial_w
)
{
ms_deform_attn_col2im_bilinear
(
data_value_ptr
,
spatial_h
,
spatial_w
,
num_heads
,
channels
,
h_im
,
w_im
,
m_col
,
c_col
,
top_grad
,
weight
,
grad_value_ptr
,
cache_grad_sampling_loc
+
(
threadIdx
.
x
<<
1
),
cache_grad_attn_weight
+
threadIdx
.
x
);
}
__syncthreads
();
if
(
tid
==
0
)
{
scalar_t
_grad_w
=
cache_grad_sampling_loc
[
0
],
_grad_h
=
cache_grad_sampling_loc
[
1
],
_grad_a
=
cache_grad_attn_weight
[
0
];
int
sid
=
2
;
for
(
unsigned
int
tid
=
1
;
tid
<
blockSize
;
++
tid
)
{
_grad_w
+=
cache_grad_sampling_loc
[
sid
];
_grad_h
+=
cache_grad_sampling_loc
[
sid
+
1
];
_grad_a
+=
cache_grad_attn_weight
[
tid
];
sid
+=
2
;
}
*
grad_sampling_loc
=
_grad_w
;
*
(
grad_sampling_loc
+
1
)
=
_grad_h
;
*
grad_attn_weight
=
_grad_a
;
}
__syncthreads
();
data_weight_ptr
+=
1
;
data_loc_w_ptr
+=
2
;
grad_attn_weight
+=
grad_weight_stride
;
grad_sampling_loc
+=
grad_loc_stride
;
}
}
}
}
template
<
typename
scalar_t
,
unsigned
int
blockSize
>
__global__
void
ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v2
(
const
int
n
,
const
scalar_t
*
grad_col
,
const
scalar_t
*
data_value
,
const
int64_t
*
data_spatial_shapes
,
const
int64_t
*
data_level_start_index
,
const
scalar_t
*
data_sampling_loc
,
const
scalar_t
*
data_attn_weight
,
const
int
batch_size
,
const
int
spatial_size
,
const
int
num_heads
,
const
int
channels
,
const
int
num_levels
,
const
int
num_query
,
const
int
num_point
,
scalar_t
*
grad_value
,
scalar_t
*
grad_sampling_loc
,
scalar_t
*
grad_attn_weight
)
{
CUDA_KERNEL_LOOP
(
index
,
n
)
{
__shared__
scalar_t
cache_grad_sampling_loc
[
blockSize
*
2
];
__shared__
scalar_t
cache_grad_attn_weight
[
blockSize
];
unsigned
int
tid
=
threadIdx
.
x
;
int
_temp
=
index
;
const
int
c_col
=
_temp
%
channels
;
_temp
/=
channels
;
const
int
sampling_index
=
_temp
;
const
int
m_col
=
_temp
%
num_heads
;
_temp
/=
num_heads
;
const
int
q_col
=
_temp
%
num_query
;
_temp
/=
num_query
;
const
int
b_col
=
_temp
;
const
scalar_t
top_grad
=
grad_col
[
index
];
int
data_weight_ptr
=
sampling_index
*
num_levels
*
num_point
;
int
data_loc_w_ptr
=
data_weight_ptr
<<
1
;
const
int
grad_sampling_ptr
=
data_weight_ptr
;
grad_sampling_loc
+=
grad_sampling_ptr
<<
1
;
grad_attn_weight
+=
grad_sampling_ptr
;
const
int
grad_weight_stride
=
1
;
const
int
grad_loc_stride
=
2
;
const
int
qid_stride
=
num_heads
*
channels
;
const
int
data_value_ptr_init_offset
=
b_col
*
spatial_size
*
qid_stride
;
for
(
int
l_col
=
0
;
l_col
<
num_levels
;
++
l_col
)
{
const
int
level_start_id
=
data_level_start_index
[
l_col
];
const
int
spatial_h_ptr
=
l_col
<<
1
;
const
int
spatial_h
=
data_spatial_shapes
[
spatial_h_ptr
];
const
int
spatial_w
=
data_spatial_shapes
[
spatial_h_ptr
+
1
];
const
int
value_ptr_offset
=
data_value_ptr_init_offset
+
level_start_id
*
qid_stride
;
const
scalar_t
*
data_value_ptr
=
data_value
+
value_ptr_offset
;
scalar_t
*
grad_value_ptr
=
grad_value
+
value_ptr_offset
;
for
(
int
p_col
=
0
;
p_col
<
num_point
;
++
p_col
)
{
const
scalar_t
loc_w
=
data_sampling_loc
[
data_loc_w_ptr
];
const
scalar_t
loc_h
=
data_sampling_loc
[
data_loc_w_ptr
+
1
];
const
scalar_t
weight
=
data_attn_weight
[
data_weight_ptr
];
const
scalar_t
h_im
=
loc_h
*
spatial_h
-
0.5
;
const
scalar_t
w_im
=
loc_w
*
spatial_w
-
0.5
;
*
(
cache_grad_sampling_loc
+
(
threadIdx
.
x
<<
1
))
=
0
;
*
(
cache_grad_sampling_loc
+
((
threadIdx
.
x
<<
1
)
+
1
))
=
0
;
*
(
cache_grad_attn_weight
+
threadIdx
.
x
)
=
0
;
if
(
h_im
>
-
1
&&
w_im
>
-
1
&&
h_im
<
spatial_h
&&
w_im
<
spatial_w
)
{
ms_deform_attn_col2im_bilinear
(
data_value_ptr
,
spatial_h
,
spatial_w
,
num_heads
,
channels
,
h_im
,
w_im
,
m_col
,
c_col
,
top_grad
,
weight
,
grad_value_ptr
,
cache_grad_sampling_loc
+
(
threadIdx
.
x
<<
1
),
cache_grad_attn_weight
+
threadIdx
.
x
);
}
__syncthreads
();
for
(
unsigned
int
s
=
blockSize
/
2
;
s
>
0
;
s
>>=
1
)
{
if
(
tid
<
s
)
{
const
unsigned
int
xid1
=
tid
<<
1
;
const
unsigned
int
xid2
=
(
tid
+
s
)
<<
1
;
cache_grad_attn_weight
[
tid
]
+=
cache_grad_attn_weight
[
tid
+
s
];
cache_grad_sampling_loc
[
xid1
]
+=
cache_grad_sampling_loc
[
xid2
];
cache_grad_sampling_loc
[
xid1
+
1
]
+=
cache_grad_sampling_loc
[
xid2
+
1
];
}
__syncthreads
();
}
if
(
tid
==
0
)
{
*
grad_sampling_loc
=
cache_grad_sampling_loc
[
0
];
*
(
grad_sampling_loc
+
1
)
=
cache_grad_sampling_loc
[
1
];
*
grad_attn_weight
=
cache_grad_attn_weight
[
0
];
}
__syncthreads
();
data_weight_ptr
+=
1
;
data_loc_w_ptr
+=
2
;
grad_attn_weight
+=
grad_weight_stride
;
grad_sampling_loc
+=
grad_loc_stride
;
}
}
}
}
template
<
typename
scalar_t
>
__global__
void
ms_deformable_col2im_gpu_kernel_shm_reduce_v1
(
const
int
n
,
const
scalar_t
*
grad_col
,
const
scalar_t
*
data_value
,
const
int64_t
*
data_spatial_shapes
,
const
int64_t
*
data_level_start_index
,
const
scalar_t
*
data_sampling_loc
,
const
scalar_t
*
data_attn_weight
,
const
int
batch_size
,
const
int
spatial_size
,
const
int
num_heads
,
const
int
channels
,
const
int
num_levels
,
const
int
num_query
,
const
int
num_point
,
scalar_t
*
grad_value
,
scalar_t
*
grad_sampling_loc
,
scalar_t
*
grad_attn_weight
)
{
CUDA_KERNEL_LOOP
(
index
,
n
)
{
extern
__shared__
int
_s
[];
scalar_t
*
cache_grad_sampling_loc
=
(
scalar_t
*
)
_s
;
scalar_t
*
cache_grad_attn_weight
=
cache_grad_sampling_loc
+
2
*
blockDim
.
x
;
unsigned
int
tid
=
threadIdx
.
x
;
int
_temp
=
index
;
const
int
c_col
=
_temp
%
channels
;
_temp
/=
channels
;
const
int
sampling_index
=
_temp
;
const
int
m_col
=
_temp
%
num_heads
;
_temp
/=
num_heads
;
const
int
q_col
=
_temp
%
num_query
;
_temp
/=
num_query
;
const
int
b_col
=
_temp
;
const
scalar_t
top_grad
=
grad_col
[
index
];
int
data_weight_ptr
=
sampling_index
*
num_levels
*
num_point
;
int
data_loc_w_ptr
=
data_weight_ptr
<<
1
;
const
int
grad_sampling_ptr
=
data_weight_ptr
;
grad_sampling_loc
+=
grad_sampling_ptr
<<
1
;
grad_attn_weight
+=
grad_sampling_ptr
;
const
int
grad_weight_stride
=
1
;
const
int
grad_loc_stride
=
2
;
const
int
qid_stride
=
num_heads
*
channels
;
const
int
data_value_ptr_init_offset
=
b_col
*
spatial_size
*
qid_stride
;
for
(
int
l_col
=
0
;
l_col
<
num_levels
;
++
l_col
)
{
const
int
level_start_id
=
data_level_start_index
[
l_col
];
const
int
spatial_h_ptr
=
l_col
<<
1
;
const
int
spatial_h
=
data_spatial_shapes
[
spatial_h_ptr
];
const
int
spatial_w
=
data_spatial_shapes
[
spatial_h_ptr
+
1
];
const
int
value_ptr_offset
=
data_value_ptr_init_offset
+
level_start_id
*
qid_stride
;
const
scalar_t
*
data_value_ptr
=
data_value
+
value_ptr_offset
;
scalar_t
*
grad_value_ptr
=
grad_value
+
value_ptr_offset
;
for
(
int
p_col
=
0
;
p_col
<
num_point
;
++
p_col
)
{
const
scalar_t
loc_w
=
data_sampling_loc
[
data_loc_w_ptr
];
const
scalar_t
loc_h
=
data_sampling_loc
[
data_loc_w_ptr
+
1
];
const
scalar_t
weight
=
data_attn_weight
[
data_weight_ptr
];
const
scalar_t
h_im
=
loc_h
*
spatial_h
-
0.5
;
const
scalar_t
w_im
=
loc_w
*
spatial_w
-
0.5
;
*
(
cache_grad_sampling_loc
+
(
threadIdx
.
x
<<
1
))
=
0
;
*
(
cache_grad_sampling_loc
+
((
threadIdx
.
x
<<
1
)
+
1
))
=
0
;
*
(
cache_grad_attn_weight
+
threadIdx
.
x
)
=
0
;
if
(
h_im
>
-
1
&&
w_im
>
-
1
&&
h_im
<
spatial_h
&&
w_im
<
spatial_w
)
{
ms_deform_attn_col2im_bilinear
(
data_value_ptr
,
spatial_h
,
spatial_w
,
num_heads
,
channels
,
h_im
,
w_im
,
m_col
,
c_col
,
top_grad
,
weight
,
grad_value_ptr
,
cache_grad_sampling_loc
+
(
threadIdx
.
x
<<
1
),
cache_grad_attn_weight
+
threadIdx
.
x
);
}
__syncthreads
();
if
(
tid
==
0
)
{
scalar_t
_grad_w
=
cache_grad_sampling_loc
[
0
],
_grad_h
=
cache_grad_sampling_loc
[
1
],
_grad_a
=
cache_grad_attn_weight
[
0
];
int
sid
=
2
;
for
(
unsigned
int
tid
=
1
;
tid
<
blockDim
.
x
;
++
tid
)
{
_grad_w
+=
cache_grad_sampling_loc
[
sid
];
_grad_h
+=
cache_grad_sampling_loc
[
sid
+
1
];
_grad_a
+=
cache_grad_attn_weight
[
tid
];
sid
+=
2
;
}
*
grad_sampling_loc
=
_grad_w
;
*
(
grad_sampling_loc
+
1
)
=
_grad_h
;
*
grad_attn_weight
=
_grad_a
;
}
__syncthreads
();
data_weight_ptr
+=
1
;
data_loc_w_ptr
+=
2
;
grad_attn_weight
+=
grad_weight_stride
;
grad_sampling_loc
+=
grad_loc_stride
;
}
}
}
}
template
<
typename
scalar_t
>
__global__
void
ms_deformable_col2im_gpu_kernel_shm_reduce_v2
(
const
int
n
,
const
scalar_t
*
grad_col
,
const
scalar_t
*
data_value
,
const
int64_t
*
data_spatial_shapes
,
const
int64_t
*
data_level_start_index
,
const
scalar_t
*
data_sampling_loc
,
const
scalar_t
*
data_attn_weight
,
const
int
batch_size
,
const
int
spatial_size
,
const
int
num_heads
,
const
int
channels
,
const
int
num_levels
,
const
int
num_query
,
const
int
num_point
,
scalar_t
*
grad_value
,
scalar_t
*
grad_sampling_loc
,
scalar_t
*
grad_attn_weight
)
{
CUDA_KERNEL_LOOP
(
index
,
n
)
{
extern
__shared__
int
_s
[];
scalar_t
*
cache_grad_sampling_loc
=
(
scalar_t
*
)
_s
;
scalar_t
*
cache_grad_attn_weight
=
cache_grad_sampling_loc
+
2
*
blockDim
.
x
;
unsigned
int
tid
=
threadIdx
.
x
;
int
_temp
=
index
;
const
int
c_col
=
_temp
%
channels
;
_temp
/=
channels
;
const
int
sampling_index
=
_temp
;
const
int
m_col
=
_temp
%
num_heads
;
_temp
/=
num_heads
;
const
int
q_col
=
_temp
%
num_query
;
_temp
/=
num_query
;
const
int
b_col
=
_temp
;
const
scalar_t
top_grad
=
grad_col
[
index
];
int
data_weight_ptr
=
sampling_index
*
num_levels
*
num_point
;
int
data_loc_w_ptr
=
data_weight_ptr
<<
1
;
const
int
grad_sampling_ptr
=
data_weight_ptr
;
grad_sampling_loc
+=
grad_sampling_ptr
<<
1
;
grad_attn_weight
+=
grad_sampling_ptr
;
const
int
grad_weight_stride
=
1
;
const
int
grad_loc_stride
=
2
;
const
int
qid_stride
=
num_heads
*
channels
;
const
int
data_value_ptr_init_offset
=
b_col
*
spatial_size
*
qid_stride
;
for
(
int
l_col
=
0
;
l_col
<
num_levels
;
++
l_col
)
{
const
int
level_start_id
=
data_level_start_index
[
l_col
];
const
int
spatial_h_ptr
=
l_col
<<
1
;
const
int
spatial_h
=
data_spatial_shapes
[
spatial_h_ptr
];
const
int
spatial_w
=
data_spatial_shapes
[
spatial_h_ptr
+
1
];
const
int
value_ptr_offset
=
data_value_ptr_init_offset
+
level_start_id
*
qid_stride
;
const
scalar_t
*
data_value_ptr
=
data_value
+
value_ptr_offset
;
scalar_t
*
grad_value_ptr
=
grad_value
+
value_ptr_offset
;
for
(
int
p_col
=
0
;
p_col
<
num_point
;
++
p_col
)
{
const
scalar_t
loc_w
=
data_sampling_loc
[
data_loc_w_ptr
];
const
scalar_t
loc_h
=
data_sampling_loc
[
data_loc_w_ptr
+
1
];
const
scalar_t
weight
=
data_attn_weight
[
data_weight_ptr
];
const
scalar_t
h_im
=
loc_h
*
spatial_h
-
0.5
;
const
scalar_t
w_im
=
loc_w
*
spatial_w
-
0.5
;
*
(
cache_grad_sampling_loc
+
(
threadIdx
.
x
<<
1
))
=
0
;
*
(
cache_grad_sampling_loc
+
((
threadIdx
.
x
<<
1
)
+
1
))
=
0
;
*
(
cache_grad_attn_weight
+
threadIdx
.
x
)
=
0
;
if
(
h_im
>
-
1
&&
w_im
>
-
1
&&
h_im
<
spatial_h
&&
w_im
<
spatial_w
)
{
ms_deform_attn_col2im_bilinear
(
data_value_ptr
,
spatial_h
,
spatial_w
,
num_heads
,
channels
,
h_im
,
w_im
,
m_col
,
c_col
,
top_grad
,
weight
,
grad_value_ptr
,
cache_grad_sampling_loc
+
(
threadIdx
.
x
<<
1
),
cache_grad_attn_weight
+
threadIdx
.
x
);
}
__syncthreads
();
for
(
unsigned
int
s
=
blockDim
.
x
/
2
,
spre
=
blockDim
.
x
;
s
>
0
;
s
>>=
1
,
spre
>>=
1
)
{
if
(
tid
<
s
)
{
const
unsigned
int
xid1
=
tid
<<
1
;
const
unsigned
int
xid2
=
(
tid
+
s
)
<<
1
;
cache_grad_attn_weight
[
tid
]
+=
cache_grad_attn_weight
[
tid
+
s
];
cache_grad_sampling_loc
[
xid1
]
+=
cache_grad_sampling_loc
[
xid2
];
cache_grad_sampling_loc
[
xid1
+
1
]
+=
cache_grad_sampling_loc
[
xid2
+
1
];
if
(
tid
+
(
s
<<
1
)
<
spre
)
{
cache_grad_attn_weight
[
tid
]
+=
cache_grad_attn_weight
[
tid
+
(
s
<<
1
)];
cache_grad_sampling_loc
[
xid1
]
+=
cache_grad_sampling_loc
[
xid2
+
(
s
<<
1
)];
cache_grad_sampling_loc
[
xid1
+
1
]
+=
cache_grad_sampling_loc
[
xid2
+
1
+
(
s
<<
1
)];
}
}
__syncthreads
();
}
if
(
tid
==
0
)
{
*
grad_sampling_loc
=
cache_grad_sampling_loc
[
0
];
*
(
grad_sampling_loc
+
1
)
=
cache_grad_sampling_loc
[
1
];
*
grad_attn_weight
=
cache_grad_attn_weight
[
0
];
}
__syncthreads
();
data_weight_ptr
+=
1
;
data_loc_w_ptr
+=
2
;
grad_attn_weight
+=
grad_weight_stride
;
grad_sampling_loc
+=
grad_loc_stride
;
}
}
}
}
template
<
typename
scalar_t
>
__global__
void
ms_deformable_col2im_gpu_kernel_shm_reduce_v2_multi_blocks
(
const
int
n
,
const
scalar_t
*
grad_col
,
const
scalar_t
*
data_value
,
const
int64_t
*
data_spatial_shapes
,
const
int64_t
*
data_level_start_index
,
const
scalar_t
*
data_sampling_loc
,
const
scalar_t
*
data_attn_weight
,
const
int
batch_size
,
const
int
spatial_size
,
const
int
num_heads
,
const
int
channels
,
const
int
num_levels
,
const
int
num_query
,
const
int
num_point
,
scalar_t
*
grad_value
,
scalar_t
*
grad_sampling_loc
,
scalar_t
*
grad_attn_weight
)
{
CUDA_KERNEL_LOOP
(
index
,
n
)
{
extern
__shared__
int
_s
[];
scalar_t
*
cache_grad_sampling_loc
=
(
scalar_t
*
)
_s
;
scalar_t
*
cache_grad_attn_weight
=
cache_grad_sampling_loc
+
2
*
blockDim
.
x
;
unsigned
int
tid
=
threadIdx
.
x
;
int
_temp
=
index
;
const
int
c_col
=
_temp
%
channels
;
_temp
/=
channels
;
const
int
sampling_index
=
_temp
;
const
int
m_col
=
_temp
%
num_heads
;
_temp
/=
num_heads
;
const
int
q_col
=
_temp
%
num_query
;
_temp
/=
num_query
;
const
int
b_col
=
_temp
;
const
scalar_t
top_grad
=
grad_col
[
index
];
int
data_weight_ptr
=
sampling_index
*
num_levels
*
num_point
;
int
data_loc_w_ptr
=
data_weight_ptr
<<
1
;
const
int
grad_sampling_ptr
=
data_weight_ptr
;
grad_sampling_loc
+=
grad_sampling_ptr
<<
1
;
grad_attn_weight
+=
grad_sampling_ptr
;
const
int
grad_weight_stride
=
1
;
const
int
grad_loc_stride
=
2
;
const
int
qid_stride
=
num_heads
*
channels
;
const
int
data_value_ptr_init_offset
=
b_col
*
spatial_size
*
qid_stride
;
for
(
int
l_col
=
0
;
l_col
<
num_levels
;
++
l_col
)
{
const
int
level_start_id
=
data_level_start_index
[
l_col
];
const
int
spatial_h_ptr
=
l_col
<<
1
;
const
int
spatial_h
=
data_spatial_shapes
[
spatial_h_ptr
];
const
int
spatial_w
=
data_spatial_shapes
[
spatial_h_ptr
+
1
];
const
int
value_ptr_offset
=
data_value_ptr_init_offset
+
level_start_id
*
qid_stride
;
const
scalar_t
*
data_value_ptr
=
data_value
+
value_ptr_offset
;
scalar_t
*
grad_value_ptr
=
grad_value
+
value_ptr_offset
;
for
(
int
p_col
=
0
;
p_col
<
num_point
;
++
p_col
)
{
const
scalar_t
loc_w
=
data_sampling_loc
[
data_loc_w_ptr
];
const
scalar_t
loc_h
=
data_sampling_loc
[
data_loc_w_ptr
+
1
];
const
scalar_t
weight
=
data_attn_weight
[
data_weight_ptr
];
const
scalar_t
h_im
=
loc_h
*
spatial_h
-
0.5
;
const
scalar_t
w_im
=
loc_w
*
spatial_w
-
0.5
;
*
(
cache_grad_sampling_loc
+
(
threadIdx
.
x
<<
1
))
=
0
;
*
(
cache_grad_sampling_loc
+
((
threadIdx
.
x
<<
1
)
+
1
))
=
0
;
*
(
cache_grad_attn_weight
+
threadIdx
.
x
)
=
0
;
if
(
h_im
>
-
1
&&
w_im
>
-
1
&&
h_im
<
spatial_h
&&
w_im
<
spatial_w
)
{
ms_deform_attn_col2im_bilinear
(
data_value_ptr
,
spatial_h
,
spatial_w
,
num_heads
,
channels
,
h_im
,
w_im
,
m_col
,
c_col
,
top_grad
,
weight
,
grad_value_ptr
,
cache_grad_sampling_loc
+
(
threadIdx
.
x
<<
1
),
cache_grad_attn_weight
+
threadIdx
.
x
);
}
__syncthreads
();
for
(
unsigned
int
s
=
blockDim
.
x
/
2
,
spre
=
blockDim
.
x
;
s
>
0
;
s
>>=
1
,
spre
>>=
1
)
{
if
(
tid
<
s
)
{
const
unsigned
int
xid1
=
tid
<<
1
;
const
unsigned
int
xid2
=
(
tid
+
s
)
<<
1
;
cache_grad_attn_weight
[
tid
]
+=
cache_grad_attn_weight
[
tid
+
s
];
cache_grad_sampling_loc
[
xid1
]
+=
cache_grad_sampling_loc
[
xid2
];
cache_grad_sampling_loc
[
xid1
+
1
]
+=
cache_grad_sampling_loc
[
xid2
+
1
];
if
(
tid
+
(
s
<<
1
)
<
spre
)
{
cache_grad_attn_weight
[
tid
]
+=
cache_grad_attn_weight
[
tid
+
(
s
<<
1
)];
cache_grad_sampling_loc
[
xid1
]
+=
cache_grad_sampling_loc
[
xid2
+
(
s
<<
1
)];
cache_grad_sampling_loc
[
xid1
+
1
]
+=
cache_grad_sampling_loc
[
xid2
+
1
+
(
s
<<
1
)];
}
}
__syncthreads
();
}
if
(
tid
==
0
)
{
atomicAdd
(
grad_sampling_loc
,
cache_grad_sampling_loc
[
0
]);
atomicAdd
(
grad_sampling_loc
+
1
,
cache_grad_sampling_loc
[
1
]);
atomicAdd
(
grad_attn_weight
,
cache_grad_attn_weight
[
0
]);
}
__syncthreads
();
data_weight_ptr
+=
1
;
data_loc_w_ptr
+=
2
;
grad_attn_weight
+=
grad_weight_stride
;
grad_sampling_loc
+=
grad_loc_stride
;
}
}
}
}
template
<
typename
scalar_t
>
__global__
void
ms_deformable_col2im_gpu_kernel_gm
(
const
int
n
,
const
scalar_t
*
grad_col
,
const
scalar_t
*
data_value
,
const
int64_t
*
data_spatial_shapes
,
const
int64_t
*
data_level_start_index
,
const
scalar_t
*
data_sampling_loc
,
const
scalar_t
*
data_attn_weight
,
const
int
batch_size
,
const
int
spatial_size
,
const
int
num_heads
,
const
int
channels
,
const
int
num_levels
,
const
int
num_query
,
const
int
num_point
,
scalar_t
*
grad_value
,
scalar_t
*
grad_sampling_loc
,
scalar_t
*
grad_attn_weight
)
{
CUDA_KERNEL_LOOP
(
index
,
n
)
{
int
_temp
=
index
;
const
int
c_col
=
_temp
%
channels
;
_temp
/=
channels
;
const
int
sampling_index
=
_temp
;
const
int
m_col
=
_temp
%
num_heads
;
_temp
/=
num_heads
;
const
int
q_col
=
_temp
%
num_query
;
_temp
/=
num_query
;
const
int
b_col
=
_temp
;
const
scalar_t
top_grad
=
grad_col
[
index
];
int
data_weight_ptr
=
sampling_index
*
num_levels
*
num_point
;
int
data_loc_w_ptr
=
data_weight_ptr
<<
1
;
const
int
grad_sampling_ptr
=
data_weight_ptr
;
grad_sampling_loc
+=
grad_sampling_ptr
<<
1
;
grad_attn_weight
+=
grad_sampling_ptr
;
const
int
grad_weight_stride
=
1
;
const
int
grad_loc_stride
=
2
;
const
int
qid_stride
=
num_heads
*
channels
;
const
int
data_value_ptr_init_offset
=
b_col
*
spatial_size
*
qid_stride
;
for
(
int
l_col
=
0
;
l_col
<
num_levels
;
++
l_col
)
{
const
int
level_start_id
=
data_level_start_index
[
l_col
];
const
int
spatial_h_ptr
=
l_col
<<
1
;
const
int
spatial_h
=
data_spatial_shapes
[
spatial_h_ptr
];
const
int
spatial_w
=
data_spatial_shapes
[
spatial_h_ptr
+
1
];
const
int
value_ptr_offset
=
data_value_ptr_init_offset
+
level_start_id
*
qid_stride
;
const
scalar_t
*
data_value_ptr
=
data_value
+
value_ptr_offset
;
scalar_t
*
grad_value_ptr
=
grad_value
+
value_ptr_offset
;
for
(
int
p_col
=
0
;
p_col
<
num_point
;
++
p_col
)
{
const
scalar_t
loc_w
=
data_sampling_loc
[
data_loc_w_ptr
];
const
scalar_t
loc_h
=
data_sampling_loc
[
data_loc_w_ptr
+
1
];
const
scalar_t
weight
=
data_attn_weight
[
data_weight_ptr
];
const
scalar_t
h_im
=
loc_h
*
spatial_h
-
0.5
;
const
scalar_t
w_im
=
loc_w
*
spatial_w
-
0.5
;
if
(
h_im
>
-
1
&&
w_im
>
-
1
&&
h_im
<
spatial_h
&&
w_im
<
spatial_w
)
{
ms_deform_attn_col2im_bilinear_gm
(
data_value_ptr
,
spatial_h
,
spatial_w
,
num_heads
,
channels
,
h_im
,
w_im
,
m_col
,
c_col
,
top_grad
,
weight
,
grad_value_ptr
,
grad_sampling_loc
,
grad_attn_weight
);
}
data_weight_ptr
+=
1
;
data_loc_w_ptr
+=
2
;
grad_attn_weight
+=
grad_weight_stride
;
grad_sampling_loc
+=
grad_loc_stride
;
}
}
}
}
template
<
typename
scalar_t
>
void
ms_deformable_im2col_cuda
(
cudaStream_t
stream
,
const
scalar_t
*
data_value
,
const
int64_t
*
data_spatial_shapes
,
const
int64_t
*
data_level_start_index
,
const
scalar_t
*
data_sampling_loc
,
const
scalar_t
*
data_attn_weight
,
const
int
batch_size
,
const
int
spatial_size
,
const
int
num_heads
,
const
int
channels
,
const
int
num_levels
,
const
int
num_query
,
const
int
num_point
,
scalar_t
*
data_col
)
{
const
int
num_kernels
=
batch_size
*
num_query
*
num_heads
*
channels
;
const
int
num_actual_kernels
=
batch_size
*
num_query
*
num_heads
*
channels
;
const
int
num_threads
=
CUDA_NUM_THREADS
;
ms_deformable_im2col_gpu_kernel
<
scalar_t
>
<<<
GET_BLOCKS
(
num_actual_kernels
,
num_threads
),
num_threads
,
0
,
stream
>>>
(
num_kernels
,
data_value
,
data_spatial_shapes
,
data_level_start_index
,
data_sampling_loc
,
data_attn_weight
,
batch_size
,
spatial_size
,
num_heads
,
channels
,
num_levels
,
num_query
,
num_point
,
data_col
);
cudaError_t
err
=
cudaGetLastError
();
if
(
err
!=
cudaSuccess
)
{
printf
(
"error in ms_deformable_im2col_cuda: %s
\n
"
,
cudaGetErrorString
(
err
));
}
}
template
<
typename
scalar_t
>
void
ms_deformable_col2im_cuda
(
cudaStream_t
stream
,
const
scalar_t
*
grad_col
,
const
scalar_t
*
data_value
,
const
int64_t
*
data_spatial_shapes
,
const
int64_t
*
data_level_start_index
,
const
scalar_t
*
data_sampling_loc
,
const
scalar_t
*
data_attn_weight
,
const
int
batch_size
,
const
int
spatial_size
,
const
int
num_heads
,
const
int
channels
,
const
int
num_levels
,
const
int
num_query
,
const
int
num_point
,
scalar_t
*
grad_value
,
scalar_t
*
grad_sampling_loc
,
scalar_t
*
grad_attn_weight
)
{
const
int
num_threads
=
(
channels
>
CUDA_NUM_THREADS
)
?
CUDA_NUM_THREADS
:
channels
;
const
int
num_kernels
=
batch_size
*
num_query
*
num_heads
*
channels
;
const
int
num_actual_kernels
=
batch_size
*
num_query
*
num_heads
*
channels
;
if
(
channels
>
1024
)
{
if
((
channels
&
1023
)
==
0
)
{
ms_deformable_col2im_gpu_kernel_shm_reduce_v2_multi_blocks
<
scalar_t
>
<<<
GET_BLOCKS
(
num_actual_kernels
,
num_threads
),
num_threads
,
num_threads
*
3
*
sizeof
(
scalar_t
),
stream
>>>
(
num_kernels
,
grad_col
,
data_value
,
data_spatial_shapes
,
data_level_start_index
,
data_sampling_loc
,
data_attn_weight
,
batch_size
,
spatial_size
,
num_heads
,
channels
,
num_levels
,
num_query
,
num_point
,
grad_value
,
grad_sampling_loc
,
grad_attn_weight
);
}
else
{
ms_deformable_col2im_gpu_kernel_gm
<
scalar_t
>
<<<
GET_BLOCKS
(
num_actual_kernels
,
num_threads
),
num_threads
,
0
,
stream
>>>
(
num_kernels
,
grad_col
,
data_value
,
data_spatial_shapes
,
data_level_start_index
,
data_sampling_loc
,
data_attn_weight
,
batch_size
,
spatial_size
,
num_heads
,
channels
,
num_levels
,
num_query
,
num_point
,
grad_value
,
grad_sampling_loc
,
grad_attn_weight
);
}
}
else
{
switch
(
channels
)
{
case
1
:
ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v1
<
scalar_t
,
1
>
<<<
GET_BLOCKS
(
num_actual_kernels
,
num_threads
),
num_threads
,
0
,
stream
>>>
(
num_kernels
,
grad_col
,
data_value
,
data_spatial_shapes
,
data_level_start_index
,
data_sampling_loc
,
data_attn_weight
,
batch_size
,
spatial_size
,
num_heads
,
channels
,
num_levels
,
num_query
,
num_point
,
grad_value
,
grad_sampling_loc
,
grad_attn_weight
);
break
;
case
2
:
ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v1
<
scalar_t
,
2
>
<<<
GET_BLOCKS
(
num_actual_kernels
,
num_threads
),
num_threads
,
0
,
stream
>>>
(
num_kernels
,
grad_col
,
data_value
,
data_spatial_shapes
,
data_level_start_index
,
data_sampling_loc
,
data_attn_weight
,
batch_size
,
spatial_size
,
num_heads
,
channels
,
num_levels
,
num_query
,
num_point
,
grad_value
,
grad_sampling_loc
,
grad_attn_weight
);
break
;
case
4
:
ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v1
<
scalar_t
,
4
>
<<<
GET_BLOCKS
(
num_actual_kernels
,
num_threads
),
num_threads
,
0
,
stream
>>>
(
num_kernels
,
grad_col
,
data_value
,
data_spatial_shapes
,
data_level_start_index
,
data_sampling_loc
,
data_attn_weight
,
batch_size
,
spatial_size
,
num_heads
,
channels
,
num_levels
,
num_query
,
num_point
,
grad_value
,
grad_sampling_loc
,
grad_attn_weight
);
break
;
case
8
:
ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v1
<
scalar_t
,
8
>
<<<
GET_BLOCKS
(
num_actual_kernels
,
num_threads
),
num_threads
,
0
,
stream
>>>
(
num_kernels
,
grad_col
,
data_value
,
data_spatial_shapes
,
data_level_start_index
,
data_sampling_loc
,
data_attn_weight
,
batch_size
,
spatial_size
,
num_heads
,
channels
,
num_levels
,
num_query
,
num_point
,
grad_value
,
grad_sampling_loc
,
grad_attn_weight
);
break
;
case
16
:
ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v1
<
scalar_t
,
16
>
<<<
GET_BLOCKS
(
num_actual_kernels
,
num_threads
),
num_threads
,
0
,
stream
>>>
(
num_kernels
,
grad_col
,
data_value
,
data_spatial_shapes
,
data_level_start_index
,
data_sampling_loc
,
data_attn_weight
,
batch_size
,
spatial_size
,
num_heads
,
channels
,
num_levels
,
num_query
,
num_point
,
grad_value
,
grad_sampling_loc
,
grad_attn_weight
);
break
;
case
32
:
ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v1
<
scalar_t
,
32
>
<<<
GET_BLOCKS
(
num_actual_kernels
,
num_threads
),
num_threads
,
0
,
stream
>>>
(
num_kernels
,
grad_col
,
data_value
,
data_spatial_shapes
,
data_level_start_index
,
data_sampling_loc
,
data_attn_weight
,
batch_size
,
spatial_size
,
num_heads
,
channels
,
num_levels
,
num_query
,
num_point
,
grad_value
,
grad_sampling_loc
,
grad_attn_weight
);
break
;
case
64
:
ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v2
<
scalar_t
,
64
>
<<<
GET_BLOCKS
(
num_actual_kernels
,
num_threads
),
num_threads
,
0
,
stream
>>>
(
num_kernels
,
grad_col
,
data_value
,
data_spatial_shapes
,
data_level_start_index
,
data_sampling_loc
,
data_attn_weight
,
batch_size
,
spatial_size
,
num_heads
,
channels
,
num_levels
,
num_query
,
num_point
,
grad_value
,
grad_sampling_loc
,
grad_attn_weight
);
break
;
case
128
:
ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v2
<
scalar_t
,
128
>
<<<
GET_BLOCKS
(
num_actual_kernels
,
num_threads
),
num_threads
,
0
,
stream
>>>
(
num_kernels
,
grad_col
,
data_value
,
data_spatial_shapes
,
data_level_start_index
,
data_sampling_loc
,
data_attn_weight
,
batch_size
,
spatial_size
,
num_heads
,
channels
,
num_levels
,
num_query
,
num_point
,
grad_value
,
grad_sampling_loc
,
grad_attn_weight
);
break
;
case
256
:
ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v2
<
scalar_t
,
256
>
<<<
GET_BLOCKS
(
num_actual_kernels
,
num_threads
),
num_threads
,
0
,
stream
>>>
(
num_kernels
,
grad_col
,
data_value
,
data_spatial_shapes
,
data_level_start_index
,
data_sampling_loc
,
data_attn_weight
,
batch_size
,
spatial_size
,
num_heads
,
channels
,
num_levels
,
num_query
,
num_point
,
grad_value
,
grad_sampling_loc
,
grad_attn_weight
);
break
;
case
512
:
ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v2
<
scalar_t
,
512
>
<<<
GET_BLOCKS
(
num_actual_kernels
,
num_threads
),
num_threads
,
0
,
stream
>>>
(
num_kernels
,
grad_col
,
data_value
,
data_spatial_shapes
,
data_level_start_index
,
data_sampling_loc
,
data_attn_weight
,
batch_size
,
spatial_size
,
num_heads
,
channels
,
num_levels
,
num_query
,
num_point
,
grad_value
,
grad_sampling_loc
,
grad_attn_weight
);
break
;
case
1024
:
ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v2
<
scalar_t
,
1024
>
<<<
GET_BLOCKS
(
num_actual_kernels
,
num_threads
),
num_threads
,
0
,
stream
>>>
(
num_kernels
,
grad_col
,
data_value
,
data_spatial_shapes
,
data_level_start_index
,
data_sampling_loc
,
data_attn_weight
,
batch_size
,
spatial_size
,
num_heads
,
channels
,
num_levels
,
num_query
,
num_point
,
grad_value
,
grad_sampling_loc
,
grad_attn_weight
);
break
;
default:
if
(
channels
<
64
)
{
ms_deformable_col2im_gpu_kernel_shm_reduce_v1
<
scalar_t
>
<<<
GET_BLOCKS
(
num_actual_kernels
,
num_threads
),
num_threads
,
num_threads
*
3
*
sizeof
(
scalar_t
),
stream
>>>
(
num_kernels
,
grad_col
,
data_value
,
data_spatial_shapes
,
data_level_start_index
,
data_sampling_loc
,
data_attn_weight
,
batch_size
,
spatial_size
,
num_heads
,
channels
,
num_levels
,
num_query
,
num_point
,
grad_value
,
grad_sampling_loc
,
grad_attn_weight
);
}
else
{
ms_deformable_col2im_gpu_kernel_shm_reduce_v2
<
scalar_t
>
<<<
GET_BLOCKS
(
num_actual_kernels
,
num_threads
),
num_threads
,
num_threads
*
3
*
sizeof
(
scalar_t
),
stream
>>>
(
num_kernels
,
grad_col
,
data_value
,
data_spatial_shapes
,
data_level_start_index
,
data_sampling_loc
,
data_attn_weight
,
batch_size
,
spatial_size
,
num_heads
,
channels
,
num_levels
,
num_query
,
num_point
,
grad_value
,
grad_sampling_loc
,
grad_attn_weight
);
}
}
}
cudaError_t
err
=
cudaGetLastError
();
if
(
err
!=
cudaSuccess
)
{
printf
(
"error in ms_deformable_col2im_cuda: %s
\n
"
,
cudaGetErrorString
(
err
));
}
}
\ No newline at end of file
groundingdino/models/GroundingDINO/csrc/MsDeformAttn/ms_deform_im2col_cuda_hip.cuh
0 → 100644
View file @
34e4011b
// !!! This is a file automatically generated by hipify!!!
#include <ATen/dtk_macros.h>
#include "hip/hip_runtime.h"
/*!
**************************************************************************
* Deformable DETR
* Copyright (c) 2020 SenseTime. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 [see LICENSE for details]
**************************************************************************
* Modified from DCN (https://github.com/msracver/Deformable-ConvNets)
* Copyright (c) 2018 Microsoft
**************************************************************************
*/
#include <cstdio>
#include <algorithm>
#include <cstring>
#include <ATen/ATen.h>
#include <ATen/hip/HIPContext.h>
#include <THH/THHAtomics.cuh>
#define CUDA_KERNEL_LOOP(i, n) \
for (int i = blockIdx.x * blockDim.x + threadIdx.x; \
i < (n); \
i += blockDim.x * gridDim.x)
const
int
CUDA_NUM_THREADS
=
1024
;
inline
int
GET_BLOCKS
(
const
int
N
,
const
int
num_threads
)
{
return
(
N
+
num_threads
-
1
)
/
num_threads
;
}
template
<
typename
scalar_t
>
__device__
scalar_t
ms_deform_attn_im2col_bilinear
(
const
scalar_t
*
&
bottom_data
,
const
int
&
height
,
const
int
&
width
,
const
int
&
nheads
,
const
int
&
channels
,
const
scalar_t
&
h
,
const
scalar_t
&
w
,
const
int
&
m
,
const
int
&
c
)
{
const
int
h_low
=
floor
(
h
);
const
int
w_low
=
floor
(
w
);
const
int
h_high
=
h_low
+
1
;
const
int
w_high
=
w_low
+
1
;
const
scalar_t
lh
=
h
-
h_low
;
const
scalar_t
lw
=
w
-
w_low
;
const
scalar_t
hh
=
1
-
lh
,
hw
=
1
-
lw
;
const
int
w_stride
=
nheads
*
channels
;
const
int
h_stride
=
width
*
w_stride
;
const
int
h_low_ptr_offset
=
h_low
*
h_stride
;
const
int
h_high_ptr_offset
=
h_low_ptr_offset
+
h_stride
;
const
int
w_low_ptr_offset
=
w_low
*
w_stride
;
const
int
w_high_ptr_offset
=
w_low_ptr_offset
+
w_stride
;
const
int
base_ptr
=
m
*
channels
+
c
;
scalar_t
v1
=
0
;
if
(
h_low
>=
0
&&
w_low
>=
0
)
{
const
int
ptr1
=
h_low_ptr_offset
+
w_low_ptr_offset
+
base_ptr
;
v1
=
bottom_data
[
ptr1
];
}
scalar_t
v2
=
0
;
if
(
h_low
>=
0
&&
w_high
<=
width
-
1
)
{
const
int
ptr2
=
h_low_ptr_offset
+
w_high_ptr_offset
+
base_ptr
;
v2
=
bottom_data
[
ptr2
];
}
scalar_t
v3
=
0
;
if
(
h_high
<=
height
-
1
&&
w_low
>=
0
)
{
const
int
ptr3
=
h_high_ptr_offset
+
w_low_ptr_offset
+
base_ptr
;
v3
=
bottom_data
[
ptr3
];
}
scalar_t
v4
=
0
;
if
(
h_high
<=
height
-
1
&&
w_high
<=
width
-
1
)
{
const
int
ptr4
=
h_high_ptr_offset
+
w_high_ptr_offset
+
base_ptr
;
v4
=
bottom_data
[
ptr4
];
}
const
scalar_t
w1
=
hh
*
hw
,
w2
=
hh
*
lw
,
w3
=
lh
*
hw
,
w4
=
lh
*
lw
;
const
scalar_t
val
=
(
w1
*
v1
+
w2
*
v2
+
w3
*
v3
+
w4
*
v4
);
return
val
;
}
template
<
typename
scalar_t
>
__device__
void
ms_deform_attn_col2im_bilinear
(
const
scalar_t
*
&
bottom_data
,
const
int
&
height
,
const
int
&
width
,
const
int
&
nheads
,
const
int
&
channels
,
const
scalar_t
&
h
,
const
scalar_t
&
w
,
const
int
&
m
,
const
int
&
c
,
const
scalar_t
&
top_grad
,
const
scalar_t
&
attn_weight
,
scalar_t
*
&
grad_value
,
scalar_t
*
grad_sampling_loc
,
scalar_t
*
grad_attn_weight
)
{
const
int
h_low
=
floor
(
h
);
const
int
w_low
=
floor
(
w
);
const
int
h_high
=
h_low
+
1
;
const
int
w_high
=
w_low
+
1
;
const
scalar_t
lh
=
h
-
h_low
;
const
scalar_t
lw
=
w
-
w_low
;
const
scalar_t
hh
=
1
-
lh
,
hw
=
1
-
lw
;
const
int
w_stride
=
nheads
*
channels
;
const
int
h_stride
=
width
*
w_stride
;
const
int
h_low_ptr_offset
=
h_low
*
h_stride
;
const
int
h_high_ptr_offset
=
h_low_ptr_offset
+
h_stride
;
const
int
w_low_ptr_offset
=
w_low
*
w_stride
;
const
int
w_high_ptr_offset
=
w_low_ptr_offset
+
w_stride
;
const
int
base_ptr
=
m
*
channels
+
c
;
const
scalar_t
w1
=
hh
*
hw
,
w2
=
hh
*
lw
,
w3
=
lh
*
hw
,
w4
=
lh
*
lw
;
const
scalar_t
top_grad_value
=
top_grad
*
attn_weight
;
scalar_t
grad_h_weight
=
0
,
grad_w_weight
=
0
;
scalar_t
v1
=
0
;
if
(
h_low
>=
0
&&
w_low
>=
0
)
{
const
int
ptr1
=
h_low_ptr_offset
+
w_low_ptr_offset
+
base_ptr
;
v1
=
bottom_data
[
ptr1
];
grad_h_weight
-=
hw
*
v1
;
grad_w_weight
-=
hh
*
v1
;
atomicAdd
(
grad_value
+
ptr1
,
w1
*
top_grad_value
);
}
scalar_t
v2
=
0
;
if
(
h_low
>=
0
&&
w_high
<=
width
-
1
)
{
const
int
ptr2
=
h_low_ptr_offset
+
w_high_ptr_offset
+
base_ptr
;
v2
=
bottom_data
[
ptr2
];
grad_h_weight
-=
lw
*
v2
;
grad_w_weight
+=
hh
*
v2
;
atomicAdd
(
grad_value
+
ptr2
,
w2
*
top_grad_value
);
}
scalar_t
v3
=
0
;
if
(
h_high
<=
height
-
1
&&
w_low
>=
0
)
{
const
int
ptr3
=
h_high_ptr_offset
+
w_low_ptr_offset
+
base_ptr
;
v3
=
bottom_data
[
ptr3
];
grad_h_weight
+=
hw
*
v3
;
grad_w_weight
-=
lh
*
v3
;
atomicAdd
(
grad_value
+
ptr3
,
w3
*
top_grad_value
);
}
scalar_t
v4
=
0
;
if
(
h_high
<=
height
-
1
&&
w_high
<=
width
-
1
)
{
const
int
ptr4
=
h_high_ptr_offset
+
w_high_ptr_offset
+
base_ptr
;
v4
=
bottom_data
[
ptr4
];
grad_h_weight
+=
lw
*
v4
;
grad_w_weight
+=
lh
*
v4
;
atomicAdd
(
grad_value
+
ptr4
,
w4
*
top_grad_value
);
}
const
scalar_t
val
=
(
w1
*
v1
+
w2
*
v2
+
w3
*
v3
+
w4
*
v4
);
*
grad_attn_weight
=
top_grad
*
val
;
*
grad_sampling_loc
=
width
*
grad_w_weight
*
top_grad_value
;
*
(
grad_sampling_loc
+
1
)
=
height
*
grad_h_weight
*
top_grad_value
;
}
template
<
typename
scalar_t
>
__device__
void
ms_deform_attn_col2im_bilinear_gm
(
const
scalar_t
*
&
bottom_data
,
const
int
&
height
,
const
int
&
width
,
const
int
&
nheads
,
const
int
&
channels
,
const
scalar_t
&
h
,
const
scalar_t
&
w
,
const
int
&
m
,
const
int
&
c
,
const
scalar_t
&
top_grad
,
const
scalar_t
&
attn_weight
,
scalar_t
*
&
grad_value
,
scalar_t
*
grad_sampling_loc
,
scalar_t
*
grad_attn_weight
)
{
const
int
h_low
=
floor
(
h
);
const
int
w_low
=
floor
(
w
);
const
int
h_high
=
h_low
+
1
;
const
int
w_high
=
w_low
+
1
;
const
scalar_t
lh
=
h
-
h_low
;
const
scalar_t
lw
=
w
-
w_low
;
const
scalar_t
hh
=
1
-
lh
,
hw
=
1
-
lw
;
const
int
w_stride
=
nheads
*
channels
;
const
int
h_stride
=
width
*
w_stride
;
const
int
h_low_ptr_offset
=
h_low
*
h_stride
;
const
int
h_high_ptr_offset
=
h_low_ptr_offset
+
h_stride
;
const
int
w_low_ptr_offset
=
w_low
*
w_stride
;
const
int
w_high_ptr_offset
=
w_low_ptr_offset
+
w_stride
;
const
int
base_ptr
=
m
*
channels
+
c
;
const
scalar_t
w1
=
hh
*
hw
,
w2
=
hh
*
lw
,
w3
=
lh
*
hw
,
w4
=
lh
*
lw
;
const
scalar_t
top_grad_value
=
top_grad
*
attn_weight
;
scalar_t
grad_h_weight
=
0
,
grad_w_weight
=
0
;
scalar_t
v1
=
0
;
if
(
h_low
>=
0
&&
w_low
>=
0
)
{
const
int
ptr1
=
h_low_ptr_offset
+
w_low_ptr_offset
+
base_ptr
;
v1
=
bottom_data
[
ptr1
];
grad_h_weight
-=
hw
*
v1
;
grad_w_weight
-=
hh
*
v1
;
atomicAdd
(
grad_value
+
ptr1
,
w1
*
top_grad_value
);
}
scalar_t
v2
=
0
;
if
(
h_low
>=
0
&&
w_high
<=
width
-
1
)
{
const
int
ptr2
=
h_low_ptr_offset
+
w_high_ptr_offset
+
base_ptr
;
v2
=
bottom_data
[
ptr2
];
grad_h_weight
-=
lw
*
v2
;
grad_w_weight
+=
hh
*
v2
;
atomicAdd
(
grad_value
+
ptr2
,
w2
*
top_grad_value
);
}
scalar_t
v3
=
0
;
if
(
h_high
<=
height
-
1
&&
w_low
>=
0
)
{
const
int
ptr3
=
h_high_ptr_offset
+
w_low_ptr_offset
+
base_ptr
;
v3
=
bottom_data
[
ptr3
];
grad_h_weight
+=
hw
*
v3
;
grad_w_weight
-=
lh
*
v3
;
atomicAdd
(
grad_value
+
ptr3
,
w3
*
top_grad_value
);
}
scalar_t
v4
=
0
;
if
(
h_high
<=
height
-
1
&&
w_high
<=
width
-
1
)
{
const
int
ptr4
=
h_high_ptr_offset
+
w_high_ptr_offset
+
base_ptr
;
v4
=
bottom_data
[
ptr4
];
grad_h_weight
+=
lw
*
v4
;
grad_w_weight
+=
lh
*
v4
;
atomicAdd
(
grad_value
+
ptr4
,
w4
*
top_grad_value
);
}
const
scalar_t
val
=
(
w1
*
v1
+
w2
*
v2
+
w3
*
v3
+
w4
*
v4
);
atomicAdd
(
grad_attn_weight
,
top_grad
*
val
);
atomicAdd
(
grad_sampling_loc
,
width
*
grad_w_weight
*
top_grad_value
);
atomicAdd
(
grad_sampling_loc
+
1
,
height
*
grad_h_weight
*
top_grad_value
);
}
template
<
typename
scalar_t
>
__global__
void
ms_deformable_im2col_gpu_kernel
(
const
int
n
,
const
scalar_t
*
data_value
,
const
int64_t
*
data_spatial_shapes
,
const
int64_t
*
data_level_start_index
,
const
scalar_t
*
data_sampling_loc
,
const
scalar_t
*
data_attn_weight
,
const
int
batch_size
,
const
int
spatial_size
,
const
int
num_heads
,
const
int
channels
,
const
int
num_levels
,
const
int
num_query
,
const
int
num_point
,
scalar_t
*
data_col
)
{
CUDA_KERNEL_LOOP
(
index
,
n
)
{
int
_temp
=
index
;
const
int
c_col
=
_temp
%
channels
;
_temp
/=
channels
;
const
int
sampling_index
=
_temp
;
const
int
m_col
=
_temp
%
num_heads
;
_temp
/=
num_heads
;
const
int
q_col
=
_temp
%
num_query
;
_temp
/=
num_query
;
const
int
b_col
=
_temp
;
scalar_t
*
data_col_ptr
=
data_col
+
index
;
int
data_weight_ptr
=
sampling_index
*
num_levels
*
num_point
;
int
data_loc_w_ptr
=
data_weight_ptr
<<
1
;
const
int
qid_stride
=
num_heads
*
channels
;
const
int
data_value_ptr_init_offset
=
b_col
*
spatial_size
*
qid_stride
;
scalar_t
col
=
0
;
for
(
int
l_col
=
0
;
l_col
<
num_levels
;
++
l_col
)
{
const
int
level_start_id
=
data_level_start_index
[
l_col
];
const
int
spatial_h_ptr
=
l_col
<<
1
;
const
int
spatial_h
=
data_spatial_shapes
[
spatial_h_ptr
];
const
int
spatial_w
=
data_spatial_shapes
[
spatial_h_ptr
+
1
];
const
scalar_t
*
data_value_ptr
=
data_value
+
(
data_value_ptr_init_offset
+
level_start_id
*
qid_stride
);
for
(
int
p_col
=
0
;
p_col
<
num_point
;
++
p_col
)
{
const
scalar_t
loc_w
=
data_sampling_loc
[
data_loc_w_ptr
];
const
scalar_t
loc_h
=
data_sampling_loc
[
data_loc_w_ptr
+
1
];
const
scalar_t
weight
=
data_attn_weight
[
data_weight_ptr
];
const
scalar_t
h_im
=
loc_h
*
spatial_h
-
0.5
;
const
scalar_t
w_im
=
loc_w
*
spatial_w
-
0.5
;
if
(
h_im
>
-
1
&&
w_im
>
-
1
&&
h_im
<
spatial_h
&&
w_im
<
spatial_w
)
{
col
+=
ms_deform_attn_im2col_bilinear
(
data_value_ptr
,
spatial_h
,
spatial_w
,
num_heads
,
channels
,
h_im
,
w_im
,
m_col
,
c_col
)
*
weight
;
}
data_weight_ptr
+=
1
;
data_loc_w_ptr
+=
2
;
}
}
*
data_col_ptr
=
col
;
}
}
template
<
typename
scalar_t
,
unsigned
int
blockSize
>
__global__
void
ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v1
(
const
int
n
,
const
scalar_t
*
grad_col
,
const
scalar_t
*
data_value
,
const
int64_t
*
data_spatial_shapes
,
const
int64_t
*
data_level_start_index
,
const
scalar_t
*
data_sampling_loc
,
const
scalar_t
*
data_attn_weight
,
const
int
batch_size
,
const
int
spatial_size
,
const
int
num_heads
,
const
int
channels
,
const
int
num_levels
,
const
int
num_query
,
const
int
num_point
,
scalar_t
*
grad_value
,
scalar_t
*
grad_sampling_loc
,
scalar_t
*
grad_attn_weight
)
{
CUDA_KERNEL_LOOP
(
index
,
n
)
{
__shared__
scalar_t
cache_grad_sampling_loc
[
blockSize
*
2
];
__shared__
scalar_t
cache_grad_attn_weight
[
blockSize
];
unsigned
int
tid
=
threadIdx
.
x
;
int
_temp
=
index
;
const
int
c_col
=
_temp
%
channels
;
_temp
/=
channels
;
const
int
sampling_index
=
_temp
;
const
int
m_col
=
_temp
%
num_heads
;
_temp
/=
num_heads
;
const
int
q_col
=
_temp
%
num_query
;
_temp
/=
num_query
;
const
int
b_col
=
_temp
;
const
scalar_t
top_grad
=
grad_col
[
index
];
int
data_weight_ptr
=
sampling_index
*
num_levels
*
num_point
;
int
data_loc_w_ptr
=
data_weight_ptr
<<
1
;
const
int
grad_sampling_ptr
=
data_weight_ptr
;
grad_sampling_loc
+=
grad_sampling_ptr
<<
1
;
grad_attn_weight
+=
grad_sampling_ptr
;
const
int
grad_weight_stride
=
1
;
const
int
grad_loc_stride
=
2
;
const
int
qid_stride
=
num_heads
*
channels
;
const
int
data_value_ptr_init_offset
=
b_col
*
spatial_size
*
qid_stride
;
for
(
int
l_col
=
0
;
l_col
<
num_levels
;
++
l_col
)
{
const
int
level_start_id
=
data_level_start_index
[
l_col
];
const
int
spatial_h_ptr
=
l_col
<<
1
;
const
int
spatial_h
=
data_spatial_shapes
[
spatial_h_ptr
];
const
int
spatial_w
=
data_spatial_shapes
[
spatial_h_ptr
+
1
];
const
int
value_ptr_offset
=
data_value_ptr_init_offset
+
level_start_id
*
qid_stride
;
const
scalar_t
*
data_value_ptr
=
data_value
+
value_ptr_offset
;
scalar_t
*
grad_value_ptr
=
grad_value
+
value_ptr_offset
;
for
(
int
p_col
=
0
;
p_col
<
num_point
;
++
p_col
)
{
const
scalar_t
loc_w
=
data_sampling_loc
[
data_loc_w_ptr
];
const
scalar_t
loc_h
=
data_sampling_loc
[
data_loc_w_ptr
+
1
];
const
scalar_t
weight
=
data_attn_weight
[
data_weight_ptr
];
const
scalar_t
h_im
=
loc_h
*
spatial_h
-
0.5
;
const
scalar_t
w_im
=
loc_w
*
spatial_w
-
0.5
;
*
(
cache_grad_sampling_loc
+
(
threadIdx
.
x
<<
1
))
=
0
;
*
(
cache_grad_sampling_loc
+
((
threadIdx
.
x
<<
1
)
+
1
))
=
0
;
*
(
cache_grad_attn_weight
+
threadIdx
.
x
)
=
0
;
if
(
h_im
>
-
1
&&
w_im
>
-
1
&&
h_im
<
spatial_h
&&
w_im
<
spatial_w
)
{
ms_deform_attn_col2im_bilinear
(
data_value_ptr
,
spatial_h
,
spatial_w
,
num_heads
,
channels
,
h_im
,
w_im
,
m_col
,
c_col
,
top_grad
,
weight
,
grad_value_ptr
,
cache_grad_sampling_loc
+
(
threadIdx
.
x
<<
1
),
cache_grad_attn_weight
+
threadIdx
.
x
);
}
__syncthreads
();
if
(
tid
==
0
)
{
scalar_t
_grad_w
=
cache_grad_sampling_loc
[
0
],
_grad_h
=
cache_grad_sampling_loc
[
1
],
_grad_a
=
cache_grad_attn_weight
[
0
];
int
sid
=
2
;
for
(
unsigned
int
tid
=
1
;
tid
<
blockSize
;
++
tid
)
{
_grad_w
+=
cache_grad_sampling_loc
[
sid
];
_grad_h
+=
cache_grad_sampling_loc
[
sid
+
1
];
_grad_a
+=
cache_grad_attn_weight
[
tid
];
sid
+=
2
;
}
*
grad_sampling_loc
=
_grad_w
;
*
(
grad_sampling_loc
+
1
)
=
_grad_h
;
*
grad_attn_weight
=
_grad_a
;
}
__syncthreads
();
data_weight_ptr
+=
1
;
data_loc_w_ptr
+=
2
;
grad_attn_weight
+=
grad_weight_stride
;
grad_sampling_loc
+=
grad_loc_stride
;
}
}
}
}
template
<
typename
scalar_t
,
unsigned
int
blockSize
>
__global__
void
ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v2
(
const
int
n
,
const
scalar_t
*
grad_col
,
const
scalar_t
*
data_value
,
const
int64_t
*
data_spatial_shapes
,
const
int64_t
*
data_level_start_index
,
const
scalar_t
*
data_sampling_loc
,
const
scalar_t
*
data_attn_weight
,
const
int
batch_size
,
const
int
spatial_size
,
const
int
num_heads
,
const
int
channels
,
const
int
num_levels
,
const
int
num_query
,
const
int
num_point
,
scalar_t
*
grad_value
,
scalar_t
*
grad_sampling_loc
,
scalar_t
*
grad_attn_weight
)
{
CUDA_KERNEL_LOOP
(
index
,
n
)
{
__shared__
scalar_t
cache_grad_sampling_loc
[
blockSize
*
2
];
__shared__
scalar_t
cache_grad_attn_weight
[
blockSize
];
unsigned
int
tid
=
threadIdx
.
x
;
int
_temp
=
index
;
const
int
c_col
=
_temp
%
channels
;
_temp
/=
channels
;
const
int
sampling_index
=
_temp
;
const
int
m_col
=
_temp
%
num_heads
;
_temp
/=
num_heads
;
const
int
q_col
=
_temp
%
num_query
;
_temp
/=
num_query
;
const
int
b_col
=
_temp
;
const
scalar_t
top_grad
=
grad_col
[
index
];
int
data_weight_ptr
=
sampling_index
*
num_levels
*
num_point
;
int
data_loc_w_ptr
=
data_weight_ptr
<<
1
;
const
int
grad_sampling_ptr
=
data_weight_ptr
;
grad_sampling_loc
+=
grad_sampling_ptr
<<
1
;
grad_attn_weight
+=
grad_sampling_ptr
;
const
int
grad_weight_stride
=
1
;
const
int
grad_loc_stride
=
2
;
const
int
qid_stride
=
num_heads
*
channels
;
const
int
data_value_ptr_init_offset
=
b_col
*
spatial_size
*
qid_stride
;
for
(
int
l_col
=
0
;
l_col
<
num_levels
;
++
l_col
)
{
const
int
level_start_id
=
data_level_start_index
[
l_col
];
const
int
spatial_h_ptr
=
l_col
<<
1
;
const
int
spatial_h
=
data_spatial_shapes
[
spatial_h_ptr
];
const
int
spatial_w
=
data_spatial_shapes
[
spatial_h_ptr
+
1
];
const
int
value_ptr_offset
=
data_value_ptr_init_offset
+
level_start_id
*
qid_stride
;
const
scalar_t
*
data_value_ptr
=
data_value
+
value_ptr_offset
;
scalar_t
*
grad_value_ptr
=
grad_value
+
value_ptr_offset
;
for
(
int
p_col
=
0
;
p_col
<
num_point
;
++
p_col
)
{
const
scalar_t
loc_w
=
data_sampling_loc
[
data_loc_w_ptr
];
const
scalar_t
loc_h
=
data_sampling_loc
[
data_loc_w_ptr
+
1
];
const
scalar_t
weight
=
data_attn_weight
[
data_weight_ptr
];
const
scalar_t
h_im
=
loc_h
*
spatial_h
-
0.5
;
const
scalar_t
w_im
=
loc_w
*
spatial_w
-
0.5
;
*
(
cache_grad_sampling_loc
+
(
threadIdx
.
x
<<
1
))
=
0
;
*
(
cache_grad_sampling_loc
+
((
threadIdx
.
x
<<
1
)
+
1
))
=
0
;
*
(
cache_grad_attn_weight
+
threadIdx
.
x
)
=
0
;
if
(
h_im
>
-
1
&&
w_im
>
-
1
&&
h_im
<
spatial_h
&&
w_im
<
spatial_w
)
{
ms_deform_attn_col2im_bilinear
(
data_value_ptr
,
spatial_h
,
spatial_w
,
num_heads
,
channels
,
h_im
,
w_im
,
m_col
,
c_col
,
top_grad
,
weight
,
grad_value_ptr
,
cache_grad_sampling_loc
+
(
threadIdx
.
x
<<
1
),
cache_grad_attn_weight
+
threadIdx
.
x
);
}
__syncthreads
();
for
(
unsigned
int
s
=
blockSize
/
2
;
s
>
0
;
s
>>=
1
)
{
if
(
tid
<
s
)
{
const
unsigned
int
xid1
=
tid
<<
1
;
const
unsigned
int
xid2
=
(
tid
+
s
)
<<
1
;
cache_grad_attn_weight
[
tid
]
+=
cache_grad_attn_weight
[
tid
+
s
];
cache_grad_sampling_loc
[
xid1
]
+=
cache_grad_sampling_loc
[
xid2
];
cache_grad_sampling_loc
[
xid1
+
1
]
+=
cache_grad_sampling_loc
[
xid2
+
1
];
}
__syncthreads
();
}
if
(
tid
==
0
)
{
*
grad_sampling_loc
=
cache_grad_sampling_loc
[
0
];
*
(
grad_sampling_loc
+
1
)
=
cache_grad_sampling_loc
[
1
];
*
grad_attn_weight
=
cache_grad_attn_weight
[
0
];
}
__syncthreads
();
data_weight_ptr
+=
1
;
data_loc_w_ptr
+=
2
;
grad_attn_weight
+=
grad_weight_stride
;
grad_sampling_loc
+=
grad_loc_stride
;
}
}
}
}
template
<
typename
scalar_t
>
__global__
void
ms_deformable_col2im_gpu_kernel_shm_reduce_v1
(
const
int
n
,
const
scalar_t
*
grad_col
,
const
scalar_t
*
data_value
,
const
int64_t
*
data_spatial_shapes
,
const
int64_t
*
data_level_start_index
,
const
scalar_t
*
data_sampling_loc
,
const
scalar_t
*
data_attn_weight
,
const
int
batch_size
,
const
int
spatial_size
,
const
int
num_heads
,
const
int
channels
,
const
int
num_levels
,
const
int
num_query
,
const
int
num_point
,
scalar_t
*
grad_value
,
scalar_t
*
grad_sampling_loc
,
scalar_t
*
grad_attn_weight
)
{
CUDA_KERNEL_LOOP
(
index
,
n
)
{
extern
__shared__
int
_s
[];
scalar_t
*
cache_grad_sampling_loc
=
(
scalar_t
*
)
_s
;
scalar_t
*
cache_grad_attn_weight
=
cache_grad_sampling_loc
+
2
*
blockDim
.
x
;
unsigned
int
tid
=
threadIdx
.
x
;
int
_temp
=
index
;
const
int
c_col
=
_temp
%
channels
;
_temp
/=
channels
;
const
int
sampling_index
=
_temp
;
const
int
m_col
=
_temp
%
num_heads
;
_temp
/=
num_heads
;
const
int
q_col
=
_temp
%
num_query
;
_temp
/=
num_query
;
const
int
b_col
=
_temp
;
const
scalar_t
top_grad
=
grad_col
[
index
];
int
data_weight_ptr
=
sampling_index
*
num_levels
*
num_point
;
int
data_loc_w_ptr
=
data_weight_ptr
<<
1
;
const
int
grad_sampling_ptr
=
data_weight_ptr
;
grad_sampling_loc
+=
grad_sampling_ptr
<<
1
;
grad_attn_weight
+=
grad_sampling_ptr
;
const
int
grad_weight_stride
=
1
;
const
int
grad_loc_stride
=
2
;
const
int
qid_stride
=
num_heads
*
channels
;
const
int
data_value_ptr_init_offset
=
b_col
*
spatial_size
*
qid_stride
;
for
(
int
l_col
=
0
;
l_col
<
num_levels
;
++
l_col
)
{
const
int
level_start_id
=
data_level_start_index
[
l_col
];
const
int
spatial_h_ptr
=
l_col
<<
1
;
const
int
spatial_h
=
data_spatial_shapes
[
spatial_h_ptr
];
const
int
spatial_w
=
data_spatial_shapes
[
spatial_h_ptr
+
1
];
const
int
value_ptr_offset
=
data_value_ptr_init_offset
+
level_start_id
*
qid_stride
;
const
scalar_t
*
data_value_ptr
=
data_value
+
value_ptr_offset
;
scalar_t
*
grad_value_ptr
=
grad_value
+
value_ptr_offset
;
for
(
int
p_col
=
0
;
p_col
<
num_point
;
++
p_col
)
{
const
scalar_t
loc_w
=
data_sampling_loc
[
data_loc_w_ptr
];
const
scalar_t
loc_h
=
data_sampling_loc
[
data_loc_w_ptr
+
1
];
const
scalar_t
weight
=
data_attn_weight
[
data_weight_ptr
];
const
scalar_t
h_im
=
loc_h
*
spatial_h
-
0.5
;
const
scalar_t
w_im
=
loc_w
*
spatial_w
-
0.5
;
*
(
cache_grad_sampling_loc
+
(
threadIdx
.
x
<<
1
))
=
0
;
*
(
cache_grad_sampling_loc
+
((
threadIdx
.
x
<<
1
)
+
1
))
=
0
;
*
(
cache_grad_attn_weight
+
threadIdx
.
x
)
=
0
;
if
(
h_im
>
-
1
&&
w_im
>
-
1
&&
h_im
<
spatial_h
&&
w_im
<
spatial_w
)
{
ms_deform_attn_col2im_bilinear
(
data_value_ptr
,
spatial_h
,
spatial_w
,
num_heads
,
channels
,
h_im
,
w_im
,
m_col
,
c_col
,
top_grad
,
weight
,
grad_value_ptr
,
cache_grad_sampling_loc
+
(
threadIdx
.
x
<<
1
),
cache_grad_attn_weight
+
threadIdx
.
x
);
}
__syncthreads
();
if
(
tid
==
0
)
{
scalar_t
_grad_w
=
cache_grad_sampling_loc
[
0
],
_grad_h
=
cache_grad_sampling_loc
[
1
],
_grad_a
=
cache_grad_attn_weight
[
0
];
int
sid
=
2
;
for
(
unsigned
int
tid
=
1
;
tid
<
blockDim
.
x
;
++
tid
)
{
_grad_w
+=
cache_grad_sampling_loc
[
sid
];
_grad_h
+=
cache_grad_sampling_loc
[
sid
+
1
];
_grad_a
+=
cache_grad_attn_weight
[
tid
];
sid
+=
2
;
}
*
grad_sampling_loc
=
_grad_w
;
*
(
grad_sampling_loc
+
1
)
=
_grad_h
;
*
grad_attn_weight
=
_grad_a
;
}
__syncthreads
();
data_weight_ptr
+=
1
;
data_loc_w_ptr
+=
2
;
grad_attn_weight
+=
grad_weight_stride
;
grad_sampling_loc
+=
grad_loc_stride
;
}
}
}
}
template
<
typename
scalar_t
>
__global__
void
ms_deformable_col2im_gpu_kernel_shm_reduce_v2
(
const
int
n
,
const
scalar_t
*
grad_col
,
const
scalar_t
*
data_value
,
const
int64_t
*
data_spatial_shapes
,
const
int64_t
*
data_level_start_index
,
const
scalar_t
*
data_sampling_loc
,
const
scalar_t
*
data_attn_weight
,
const
int
batch_size
,
const
int
spatial_size
,
const
int
num_heads
,
const
int
channels
,
const
int
num_levels
,
const
int
num_query
,
const
int
num_point
,
scalar_t
*
grad_value
,
scalar_t
*
grad_sampling_loc
,
scalar_t
*
grad_attn_weight
)
{
CUDA_KERNEL_LOOP
(
index
,
n
)
{
extern
__shared__
int
_s
[];
scalar_t
*
cache_grad_sampling_loc
=
(
scalar_t
*
)
_s
;
scalar_t
*
cache_grad_attn_weight
=
cache_grad_sampling_loc
+
2
*
blockDim
.
x
;
unsigned
int
tid
=
threadIdx
.
x
;
int
_temp
=
index
;
const
int
c_col
=
_temp
%
channels
;
_temp
/=
channels
;
const
int
sampling_index
=
_temp
;
const
int
m_col
=
_temp
%
num_heads
;
_temp
/=
num_heads
;
const
int
q_col
=
_temp
%
num_query
;
_temp
/=
num_query
;
const
int
b_col
=
_temp
;
const
scalar_t
top_grad
=
grad_col
[
index
];
int
data_weight_ptr
=
sampling_index
*
num_levels
*
num_point
;
int
data_loc_w_ptr
=
data_weight_ptr
<<
1
;
const
int
grad_sampling_ptr
=
data_weight_ptr
;
grad_sampling_loc
+=
grad_sampling_ptr
<<
1
;
grad_attn_weight
+=
grad_sampling_ptr
;
const
int
grad_weight_stride
=
1
;
const
int
grad_loc_stride
=
2
;
const
int
qid_stride
=
num_heads
*
channels
;
const
int
data_value_ptr_init_offset
=
b_col
*
spatial_size
*
qid_stride
;
for
(
int
l_col
=
0
;
l_col
<
num_levels
;
++
l_col
)
{
const
int
level_start_id
=
data_level_start_index
[
l_col
];
const
int
spatial_h_ptr
=
l_col
<<
1
;
const
int
spatial_h
=
data_spatial_shapes
[
spatial_h_ptr
];
const
int
spatial_w
=
data_spatial_shapes
[
spatial_h_ptr
+
1
];
const
int
value_ptr_offset
=
data_value_ptr_init_offset
+
level_start_id
*
qid_stride
;
const
scalar_t
*
data_value_ptr
=
data_value
+
value_ptr_offset
;
scalar_t
*
grad_value_ptr
=
grad_value
+
value_ptr_offset
;
for
(
int
p_col
=
0
;
p_col
<
num_point
;
++
p_col
)
{
const
scalar_t
loc_w
=
data_sampling_loc
[
data_loc_w_ptr
];
const
scalar_t
loc_h
=
data_sampling_loc
[
data_loc_w_ptr
+
1
];
const
scalar_t
weight
=
data_attn_weight
[
data_weight_ptr
];
const
scalar_t
h_im
=
loc_h
*
spatial_h
-
0.5
;
const
scalar_t
w_im
=
loc_w
*
spatial_w
-
0.5
;
*
(
cache_grad_sampling_loc
+
(
threadIdx
.
x
<<
1
))
=
0
;
*
(
cache_grad_sampling_loc
+
((
threadIdx
.
x
<<
1
)
+
1
))
=
0
;
*
(
cache_grad_attn_weight
+
threadIdx
.
x
)
=
0
;
if
(
h_im
>
-
1
&&
w_im
>
-
1
&&
h_im
<
spatial_h
&&
w_im
<
spatial_w
)
{
ms_deform_attn_col2im_bilinear
(
data_value_ptr
,
spatial_h
,
spatial_w
,
num_heads
,
channels
,
h_im
,
w_im
,
m_col
,
c_col
,
top_grad
,
weight
,
grad_value_ptr
,
cache_grad_sampling_loc
+
(
threadIdx
.
x
<<
1
),
cache_grad_attn_weight
+
threadIdx
.
x
);
}
__syncthreads
();
for
(
unsigned
int
s
=
blockDim
.
x
/
2
,
spre
=
blockDim
.
x
;
s
>
0
;
s
>>=
1
,
spre
>>=
1
)
{
if
(
tid
<
s
)
{
const
unsigned
int
xid1
=
tid
<<
1
;
const
unsigned
int
xid2
=
(
tid
+
s
)
<<
1
;
cache_grad_attn_weight
[
tid
]
+=
cache_grad_attn_weight
[
tid
+
s
];
cache_grad_sampling_loc
[
xid1
]
+=
cache_grad_sampling_loc
[
xid2
];
cache_grad_sampling_loc
[
xid1
+
1
]
+=
cache_grad_sampling_loc
[
xid2
+
1
];
if
(
tid
+
(
s
<<
1
)
<
spre
)
{
cache_grad_attn_weight
[
tid
]
+=
cache_grad_attn_weight
[
tid
+
(
s
<<
1
)];
cache_grad_sampling_loc
[
xid1
]
+=
cache_grad_sampling_loc
[
xid2
+
(
s
<<
1
)];
cache_grad_sampling_loc
[
xid1
+
1
]
+=
cache_grad_sampling_loc
[
xid2
+
1
+
(
s
<<
1
)];
}
}
__syncthreads
();
}
if
(
tid
==
0
)
{
*
grad_sampling_loc
=
cache_grad_sampling_loc
[
0
];
*
(
grad_sampling_loc
+
1
)
=
cache_grad_sampling_loc
[
1
];
*
grad_attn_weight
=
cache_grad_attn_weight
[
0
];
}
__syncthreads
();
data_weight_ptr
+=
1
;
data_loc_w_ptr
+=
2
;
grad_attn_weight
+=
grad_weight_stride
;
grad_sampling_loc
+=
grad_loc_stride
;
}
}
}
}
template
<
typename
scalar_t
>
__global__
void
ms_deformable_col2im_gpu_kernel_shm_reduce_v2_multi_blocks
(
const
int
n
,
const
scalar_t
*
grad_col
,
const
scalar_t
*
data_value
,
const
int64_t
*
data_spatial_shapes
,
const
int64_t
*
data_level_start_index
,
const
scalar_t
*
data_sampling_loc
,
const
scalar_t
*
data_attn_weight
,
const
int
batch_size
,
const
int
spatial_size
,
const
int
num_heads
,
const
int
channels
,
const
int
num_levels
,
const
int
num_query
,
const
int
num_point
,
scalar_t
*
grad_value
,
scalar_t
*
grad_sampling_loc
,
scalar_t
*
grad_attn_weight
)
{
CUDA_KERNEL_LOOP
(
index
,
n
)
{
extern
__shared__
int
_s
[];
scalar_t
*
cache_grad_sampling_loc
=
(
scalar_t
*
)
_s
;
scalar_t
*
cache_grad_attn_weight
=
cache_grad_sampling_loc
+
2
*
blockDim
.
x
;
unsigned
int
tid
=
threadIdx
.
x
;
int
_temp
=
index
;
const
int
c_col
=
_temp
%
channels
;
_temp
/=
channels
;
const
int
sampling_index
=
_temp
;
const
int
m_col
=
_temp
%
num_heads
;
_temp
/=
num_heads
;
const
int
q_col
=
_temp
%
num_query
;
_temp
/=
num_query
;
const
int
b_col
=
_temp
;
const
scalar_t
top_grad
=
grad_col
[
index
];
int
data_weight_ptr
=
sampling_index
*
num_levels
*
num_point
;
int
data_loc_w_ptr
=
data_weight_ptr
<<
1
;
const
int
grad_sampling_ptr
=
data_weight_ptr
;
grad_sampling_loc
+=
grad_sampling_ptr
<<
1
;
grad_attn_weight
+=
grad_sampling_ptr
;
const
int
grad_weight_stride
=
1
;
const
int
grad_loc_stride
=
2
;
const
int
qid_stride
=
num_heads
*
channels
;
const
int
data_value_ptr_init_offset
=
b_col
*
spatial_size
*
qid_stride
;
for
(
int
l_col
=
0
;
l_col
<
num_levels
;
++
l_col
)
{
const
int
level_start_id
=
data_level_start_index
[
l_col
];
const
int
spatial_h_ptr
=
l_col
<<
1
;
const
int
spatial_h
=
data_spatial_shapes
[
spatial_h_ptr
];
const
int
spatial_w
=
data_spatial_shapes
[
spatial_h_ptr
+
1
];
const
int
value_ptr_offset
=
data_value_ptr_init_offset
+
level_start_id
*
qid_stride
;
const
scalar_t
*
data_value_ptr
=
data_value
+
value_ptr_offset
;
scalar_t
*
grad_value_ptr
=
grad_value
+
value_ptr_offset
;
for
(
int
p_col
=
0
;
p_col
<
num_point
;
++
p_col
)
{
const
scalar_t
loc_w
=
data_sampling_loc
[
data_loc_w_ptr
];
const
scalar_t
loc_h
=
data_sampling_loc
[
data_loc_w_ptr
+
1
];
const
scalar_t
weight
=
data_attn_weight
[
data_weight_ptr
];
const
scalar_t
h_im
=
loc_h
*
spatial_h
-
0.5
;
const
scalar_t
w_im
=
loc_w
*
spatial_w
-
0.5
;
*
(
cache_grad_sampling_loc
+
(
threadIdx
.
x
<<
1
))
=
0
;
*
(
cache_grad_sampling_loc
+
((
threadIdx
.
x
<<
1
)
+
1
))
=
0
;
*
(
cache_grad_attn_weight
+
threadIdx
.
x
)
=
0
;
if
(
h_im
>
-
1
&&
w_im
>
-
1
&&
h_im
<
spatial_h
&&
w_im
<
spatial_w
)
{
ms_deform_attn_col2im_bilinear
(
data_value_ptr
,
spatial_h
,
spatial_w
,
num_heads
,
channels
,
h_im
,
w_im
,
m_col
,
c_col
,
top_grad
,
weight
,
grad_value_ptr
,
cache_grad_sampling_loc
+
(
threadIdx
.
x
<<
1
),
cache_grad_attn_weight
+
threadIdx
.
x
);
}
__syncthreads
();
for
(
unsigned
int
s
=
blockDim
.
x
/
2
,
spre
=
blockDim
.
x
;
s
>
0
;
s
>>=
1
,
spre
>>=
1
)
{
if
(
tid
<
s
)
{
const
unsigned
int
xid1
=
tid
<<
1
;
const
unsigned
int
xid2
=
(
tid
+
s
)
<<
1
;
cache_grad_attn_weight
[
tid
]
+=
cache_grad_attn_weight
[
tid
+
s
];
cache_grad_sampling_loc
[
xid1
]
+=
cache_grad_sampling_loc
[
xid2
];
cache_grad_sampling_loc
[
xid1
+
1
]
+=
cache_grad_sampling_loc
[
xid2
+
1
];
if
(
tid
+
(
s
<<
1
)
<
spre
)
{
cache_grad_attn_weight
[
tid
]
+=
cache_grad_attn_weight
[
tid
+
(
s
<<
1
)];
cache_grad_sampling_loc
[
xid1
]
+=
cache_grad_sampling_loc
[
xid2
+
(
s
<<
1
)];
cache_grad_sampling_loc
[
xid1
+
1
]
+=
cache_grad_sampling_loc
[
xid2
+
1
+
(
s
<<
1
)];
}
}
__syncthreads
();
}
if
(
tid
==
0
)
{
atomicAdd
(
grad_sampling_loc
,
cache_grad_sampling_loc
[
0
]);
atomicAdd
(
grad_sampling_loc
+
1
,
cache_grad_sampling_loc
[
1
]);
atomicAdd
(
grad_attn_weight
,
cache_grad_attn_weight
[
0
]);
}
__syncthreads
();
data_weight_ptr
+=
1
;
data_loc_w_ptr
+=
2
;
grad_attn_weight
+=
grad_weight_stride
;
grad_sampling_loc
+=
grad_loc_stride
;
}
}
}
}
template
<
typename
scalar_t
>
__global__
void
ms_deformable_col2im_gpu_kernel_gm
(
const
int
n
,
const
scalar_t
*
grad_col
,
const
scalar_t
*
data_value
,
const
int64_t
*
data_spatial_shapes
,
const
int64_t
*
data_level_start_index
,
const
scalar_t
*
data_sampling_loc
,
const
scalar_t
*
data_attn_weight
,
const
int
batch_size
,
const
int
spatial_size
,
const
int
num_heads
,
const
int
channels
,
const
int
num_levels
,
const
int
num_query
,
const
int
num_point
,
scalar_t
*
grad_value
,
scalar_t
*
grad_sampling_loc
,
scalar_t
*
grad_attn_weight
)
{
CUDA_KERNEL_LOOP
(
index
,
n
)
{
int
_temp
=
index
;
const
int
c_col
=
_temp
%
channels
;
_temp
/=
channels
;
const
int
sampling_index
=
_temp
;
const
int
m_col
=
_temp
%
num_heads
;
_temp
/=
num_heads
;
const
int
q_col
=
_temp
%
num_query
;
_temp
/=
num_query
;
const
int
b_col
=
_temp
;
const
scalar_t
top_grad
=
grad_col
[
index
];
int
data_weight_ptr
=
sampling_index
*
num_levels
*
num_point
;
int
data_loc_w_ptr
=
data_weight_ptr
<<
1
;
const
int
grad_sampling_ptr
=
data_weight_ptr
;
grad_sampling_loc
+=
grad_sampling_ptr
<<
1
;
grad_attn_weight
+=
grad_sampling_ptr
;
const
int
grad_weight_stride
=
1
;
const
int
grad_loc_stride
=
2
;
const
int
qid_stride
=
num_heads
*
channels
;
const
int
data_value_ptr_init_offset
=
b_col
*
spatial_size
*
qid_stride
;
for
(
int
l_col
=
0
;
l_col
<
num_levels
;
++
l_col
)
{
const
int
level_start_id
=
data_level_start_index
[
l_col
];
const
int
spatial_h_ptr
=
l_col
<<
1
;
const
int
spatial_h
=
data_spatial_shapes
[
spatial_h_ptr
];
const
int
spatial_w
=
data_spatial_shapes
[
spatial_h_ptr
+
1
];
const
int
value_ptr_offset
=
data_value_ptr_init_offset
+
level_start_id
*
qid_stride
;
const
scalar_t
*
data_value_ptr
=
data_value
+
value_ptr_offset
;
scalar_t
*
grad_value_ptr
=
grad_value
+
value_ptr_offset
;
for
(
int
p_col
=
0
;
p_col
<
num_point
;
++
p_col
)
{
const
scalar_t
loc_w
=
data_sampling_loc
[
data_loc_w_ptr
];
const
scalar_t
loc_h
=
data_sampling_loc
[
data_loc_w_ptr
+
1
];
const
scalar_t
weight
=
data_attn_weight
[
data_weight_ptr
];
const
scalar_t
h_im
=
loc_h
*
spatial_h
-
0.5
;
const
scalar_t
w_im
=
loc_w
*
spatial_w
-
0.5
;
if
(
h_im
>
-
1
&&
w_im
>
-
1
&&
h_im
<
spatial_h
&&
w_im
<
spatial_w
)
{
ms_deform_attn_col2im_bilinear_gm
(
data_value_ptr
,
spatial_h
,
spatial_w
,
num_heads
,
channels
,
h_im
,
w_im
,
m_col
,
c_col
,
top_grad
,
weight
,
grad_value_ptr
,
grad_sampling_loc
,
grad_attn_weight
);
}
data_weight_ptr
+=
1
;
data_loc_w_ptr
+=
2
;
grad_attn_weight
+=
grad_weight_stride
;
grad_sampling_loc
+=
grad_loc_stride
;
}
}
}
}
template
<
typename
scalar_t
>
void
ms_deformable_im2col_cuda
(
hipStream_t
stream
,
const
scalar_t
*
data_value
,
const
int64_t
*
data_spatial_shapes
,
const
int64_t
*
data_level_start_index
,
const
scalar_t
*
data_sampling_loc
,
const
scalar_t
*
data_attn_weight
,
const
int
batch_size
,
const
int
spatial_size
,
const
int
num_heads
,
const
int
channels
,
const
int
num_levels
,
const
int
num_query
,
const
int
num_point
,
scalar_t
*
data_col
)
{
const
int
num_kernels
=
batch_size
*
num_query
*
num_heads
*
channels
;
const
int
num_actual_kernels
=
batch_size
*
num_query
*
num_heads
*
channels
;
const
int
num_threads
=
CUDA_NUM_THREADS
;
hipLaunchKernelGGL
((
ms_deformable_im2col_gpu_kernel
<
scalar_t
>
)
,
dim3
(
GET_BLOCKS
(
num_actual_kernels
,
num_threads
)),
dim3
(
num_threads
),
0
,
stream
,
num_kernels
,
data_value
,
data_spatial_shapes
,
data_level_start_index
,
data_sampling_loc
,
data_attn_weight
,
batch_size
,
spatial_size
,
num_heads
,
channels
,
num_levels
,
num_query
,
num_point
,
data_col
);
hipError_t
err
=
hipGetLastError
();
if
(
err
!=
hipSuccess
)
{
printf
(
"error in ms_deformable_im2col_cuda: %s
\n
"
,
hipGetErrorString
(
err
));
}
}
template
<
typename
scalar_t
>
void
ms_deformable_col2im_cuda
(
hipStream_t
stream
,
const
scalar_t
*
grad_col
,
const
scalar_t
*
data_value
,
const
int64_t
*
data_spatial_shapes
,
const
int64_t
*
data_level_start_index
,
const
scalar_t
*
data_sampling_loc
,
const
scalar_t
*
data_attn_weight
,
const
int
batch_size
,
const
int
spatial_size
,
const
int
num_heads
,
const
int
channels
,
const
int
num_levels
,
const
int
num_query
,
const
int
num_point
,
scalar_t
*
grad_value
,
scalar_t
*
grad_sampling_loc
,
scalar_t
*
grad_attn_weight
)
{
const
int
num_threads
=
(
channels
>
CUDA_NUM_THREADS
)
?
CUDA_NUM_THREADS
:
channels
;
const
int
num_kernels
=
batch_size
*
num_query
*
num_heads
*
channels
;
const
int
num_actual_kernels
=
batch_size
*
num_query
*
num_heads
*
channels
;
if
(
channels
>
1024
)
{
if
((
channels
&
1023
)
==
0
)
{
hipLaunchKernelGGL
((
ms_deformable_col2im_gpu_kernel_shm_reduce_v2_multi_blocks
<
scalar_t
>
)
,
dim3
(
GET_BLOCKS
(
num_actual_kernels
,
num_threads
)),
dim3
(
num_threads
),
num_threads
*
3
*
sizeof
(
scalar_t
),
stream
,
num_kernels
,
grad_col
,
data_value
,
data_spatial_shapes
,
data_level_start_index
,
data_sampling_loc
,
data_attn_weight
,
batch_size
,
spatial_size
,
num_heads
,
channels
,
num_levels
,
num_query
,
num_point
,
grad_value
,
grad_sampling_loc
,
grad_attn_weight
);
}
else
{
hipLaunchKernelGGL
((
ms_deformable_col2im_gpu_kernel_gm
<
scalar_t
>
)
,
dim3
(
GET_BLOCKS
(
num_actual_kernels
,
num_threads
)),
dim3
(
num_threads
),
0
,
stream
,
num_kernels
,
grad_col
,
data_value
,
data_spatial_shapes
,
data_level_start_index
,
data_sampling_loc
,
data_attn_weight
,
batch_size
,
spatial_size
,
num_heads
,
channels
,
num_levels
,
num_query
,
num_point
,
grad_value
,
grad_sampling_loc
,
grad_attn_weight
);
}
}
else
{
switch
(
channels
)
{
case
1
:
hipLaunchKernelGGL
((
ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v1
<
scalar_t
,
1
>
)
,
dim3
(
GET_BLOCKS
(
num_actual_kernels
,
num_threads
)),
dim3
(
num_threads
),
0
,
stream
,
num_kernels
,
grad_col
,
data_value
,
data_spatial_shapes
,
data_level_start_index
,
data_sampling_loc
,
data_attn_weight
,
batch_size
,
spatial_size
,
num_heads
,
channels
,
num_levels
,
num_query
,
num_point
,
grad_value
,
grad_sampling_loc
,
grad_attn_weight
);
break
;
case
2
:
hipLaunchKernelGGL
((
ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v1
<
scalar_t
,
2
>
)
,
dim3
(
GET_BLOCKS
(
num_actual_kernels
,
num_threads
)),
dim3
(
num_threads
),
0
,
stream
,
num_kernels
,
grad_col
,
data_value
,
data_spatial_shapes
,
data_level_start_index
,
data_sampling_loc
,
data_attn_weight
,
batch_size
,
spatial_size
,
num_heads
,
channels
,
num_levels
,
num_query
,
num_point
,
grad_value
,
grad_sampling_loc
,
grad_attn_weight
);
break
;
case
4
:
hipLaunchKernelGGL
((
ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v1
<
scalar_t
,
4
>
)
,
dim3
(
GET_BLOCKS
(
num_actual_kernels
,
num_threads
)),
dim3
(
num_threads
),
0
,
stream
,
num_kernels
,
grad_col
,
data_value
,
data_spatial_shapes
,
data_level_start_index
,
data_sampling_loc
,
data_attn_weight
,
batch_size
,
spatial_size
,
num_heads
,
channels
,
num_levels
,
num_query
,
num_point
,
grad_value
,
grad_sampling_loc
,
grad_attn_weight
);
break
;
case
8
:
hipLaunchKernelGGL
((
ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v1
<
scalar_t
,
8
>
)
,
dim3
(
GET_BLOCKS
(
num_actual_kernels
,
num_threads
)),
dim3
(
num_threads
),
0
,
stream
,
num_kernels
,
grad_col
,
data_value
,
data_spatial_shapes
,
data_level_start_index
,
data_sampling_loc
,
data_attn_weight
,
batch_size
,
spatial_size
,
num_heads
,
channels
,
num_levels
,
num_query
,
num_point
,
grad_value
,
grad_sampling_loc
,
grad_attn_weight
);
break
;
case
16
:
hipLaunchKernelGGL
((
ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v1
<
scalar_t
,
16
>
)
,
dim3
(
GET_BLOCKS
(
num_actual_kernels
,
num_threads
)),
dim3
(
num_threads
),
0
,
stream
,
num_kernels
,
grad_col
,
data_value
,
data_spatial_shapes
,
data_level_start_index
,
data_sampling_loc
,
data_attn_weight
,
batch_size
,
spatial_size
,
num_heads
,
channels
,
num_levels
,
num_query
,
num_point
,
grad_value
,
grad_sampling_loc
,
grad_attn_weight
);
break
;
case
32
:
hipLaunchKernelGGL
((
ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v1
<
scalar_t
,
32
>
)
,
dim3
(
GET_BLOCKS
(
num_actual_kernels
,
num_threads
)),
dim3
(
num_threads
),
0
,
stream
,
num_kernels
,
grad_col
,
data_value
,
data_spatial_shapes
,
data_level_start_index
,
data_sampling_loc
,
data_attn_weight
,
batch_size
,
spatial_size
,
num_heads
,
channels
,
num_levels
,
num_query
,
num_point
,
grad_value
,
grad_sampling_loc
,
grad_attn_weight
);
break
;
case
64
:
hipLaunchKernelGGL
((
ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v2
<
scalar_t
,
64
>
)
,
dim3
(
GET_BLOCKS
(
num_actual_kernels
,
num_threads
)),
dim3
(
num_threads
),
0
,
stream
,
num_kernels
,
grad_col
,
data_value
,
data_spatial_shapes
,
data_level_start_index
,
data_sampling_loc
,
data_attn_weight
,
batch_size
,
spatial_size
,
num_heads
,
channels
,
num_levels
,
num_query
,
num_point
,
grad_value
,
grad_sampling_loc
,
grad_attn_weight
);
break
;
case
128
:
hipLaunchKernelGGL
((
ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v2
<
scalar_t
,
128
>
)
,
dim3
(
GET_BLOCKS
(
num_actual_kernels
,
num_threads
)),
dim3
(
num_threads
),
0
,
stream
,
num_kernels
,
grad_col
,
data_value
,
data_spatial_shapes
,
data_level_start_index
,
data_sampling_loc
,
data_attn_weight
,
batch_size
,
spatial_size
,
num_heads
,
channels
,
num_levels
,
num_query
,
num_point
,
grad_value
,
grad_sampling_loc
,
grad_attn_weight
);
break
;
case
256
:
hipLaunchKernelGGL
((
ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v2
<
scalar_t
,
256
>
)
,
dim3
(
GET_BLOCKS
(
num_actual_kernels
,
num_threads
)),
dim3
(
num_threads
),
0
,
stream
,
num_kernels
,
grad_col
,
data_value
,
data_spatial_shapes
,
data_level_start_index
,
data_sampling_loc
,
data_attn_weight
,
batch_size
,
spatial_size
,
num_heads
,
channels
,
num_levels
,
num_query
,
num_point
,
grad_value
,
grad_sampling_loc
,
grad_attn_weight
);
break
;
case
512
:
hipLaunchKernelGGL
((
ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v2
<
scalar_t
,
512
>
)
,
dim3
(
GET_BLOCKS
(
num_actual_kernels
,
num_threads
)),
dim3
(
num_threads
),
0
,
stream
,
num_kernels
,
grad_col
,
data_value
,
data_spatial_shapes
,
data_level_start_index
,
data_sampling_loc
,
data_attn_weight
,
batch_size
,
spatial_size
,
num_heads
,
channels
,
num_levels
,
num_query
,
num_point
,
grad_value
,
grad_sampling_loc
,
grad_attn_weight
);
break
;
case
1024
:
hipLaunchKernelGGL
((
ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v2
<
scalar_t
,
1024
>
)
,
dim3
(
GET_BLOCKS
(
num_actual_kernels
,
num_threads
)),
dim3
(
num_threads
),
0
,
stream
,
num_kernels
,
grad_col
,
data_value
,
data_spatial_shapes
,
data_level_start_index
,
data_sampling_loc
,
data_attn_weight
,
batch_size
,
spatial_size
,
num_heads
,
channels
,
num_levels
,
num_query
,
num_point
,
grad_value
,
grad_sampling_loc
,
grad_attn_weight
);
break
;
default:
if
(
channels
<
64
)
{
hipLaunchKernelGGL
((
ms_deformable_col2im_gpu_kernel_shm_reduce_v1
<
scalar_t
>
)
,
dim3
(
GET_BLOCKS
(
num_actual_kernels
,
num_threads
)),
dim3
(
num_threads
),
num_threads
*
3
*
sizeof
(
scalar_t
),
stream
,
num_kernels
,
grad_col
,
data_value
,
data_spatial_shapes
,
data_level_start_index
,
data_sampling_loc
,
data_attn_weight
,
batch_size
,
spatial_size
,
num_heads
,
channels
,
num_levels
,
num_query
,
num_point
,
grad_value
,
grad_sampling_loc
,
grad_attn_weight
);
}
else
{
hipLaunchKernelGGL
((
ms_deformable_col2im_gpu_kernel_shm_reduce_v2
<
scalar_t
>
)
,
dim3
(
GET_BLOCKS
(
num_actual_kernels
,
num_threads
)),
dim3
(
num_threads
),
num_threads
*
3
*
sizeof
(
scalar_t
),
stream
,
num_kernels
,
grad_col
,
data_value
,
data_spatial_shapes
,
data_level_start_index
,
data_sampling_loc
,
data_attn_weight
,
batch_size
,
spatial_size
,
num_heads
,
channels
,
num_levels
,
num_query
,
num_point
,
grad_value
,
grad_sampling_loc
,
grad_attn_weight
);
}
}
}
hipError_t
err
=
hipGetLastError
();
if
(
err
!=
hipSuccess
)
{
printf
(
"error in ms_deformable_col2im_cuda: %s
\n
"
,
hipGetErrorString
(
err
));
}
}
\ No newline at end of file
groundingdino/models/GroundingDINO/csrc/cuda_version.cu
0 → 100644
View file @
34e4011b
#include <cuda_runtime_api.h>
namespace
groundingdino
{
int
get_cudart_version
()
{
return
CUDART_VERSION
;
}
}
// namespace groundingdino
groundingdino/models/GroundingDINO/csrc/hip_version.hip
0 → 100644
View file @
34e4011b
// !!! This is a file automatically generated by hipify!!!
#include <ATen/dtk_macros.h>
#include <hip/hip_runtime_api.h>
namespace groundingdino {
int get_cudart_version() {
return DTKRT_VERSION;
}
} // namespace groundingdino
groundingdino/models/GroundingDINO/csrc/vision.cpp
0 → 100644
View file @
34e4011b
// Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
#include "MsDeformAttn/ms_deform_attn.h"
namespace
groundingdino
{
#ifdef WITH_CUDA
extern
int
get_cudart_version
();
#endif
std
::
string
get_cuda_version
()
{
#ifdef WITH_CUDA
std
::
ostringstream
oss
;
// copied from
// https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/cuda/detail/CUDAHooks.cpp#L231
auto
printCudaStyleVersion
=
[
&
](
int
v
)
{
oss
<<
(
v
/
1000
)
<<
"."
<<
(
v
/
10
%
100
);
if
(
v
%
10
!=
0
)
{
oss
<<
"."
<<
(
v
%
10
);
}
};
printCudaStyleVersion
(
get_cudart_version
());
return
oss
.
str
();
#else
return
std
::
string
(
"not available"
);
#endif
}
// similar to
// https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/Version.cpp
std
::
string
get_compiler_version
()
{
std
::
ostringstream
ss
;
#if defined(__GNUC__)
#ifndef __clang__
{
ss
<<
"GCC "
<<
__GNUC__
<<
"."
<<
__GNUC_MINOR__
;
}
#endif
#endif
#if defined(__clang_major__)
{
ss
<<
"clang "
<<
__clang_major__
<<
"."
<<
__clang_minor__
<<
"."
<<
__clang_patchlevel__
;
}
#endif
#if defined(_MSC_VER)
{
ss
<<
"MSVC "
<<
_MSC_FULL_VER
;
}
#endif
return
ss
.
str
();
}
PYBIND11_MODULE
(
TORCH_EXTENSION_NAME
,
m
)
{
m
.
def
(
"ms_deform_attn_forward"
,
&
ms_deform_attn_forward
,
"ms_deform_attn_forward"
);
m
.
def
(
"ms_deform_attn_backward"
,
&
ms_deform_attn_backward
,
"ms_deform_attn_backward"
);
}
}
// namespace groundingdino
\ No newline at end of file
groundingdino/models/GroundingDINO/fuse_modules.py
0 → 100644
View file @
34e4011b
# ------------------------------------------------------------------------
# Grounding DINO
# url: https://github.com/IDEA-Research/GroundingDINO
# Copyright (c) 2023 IDEA. All Rights Reserved.
# Licensed under the Apache License, Version 2.0 [see LICENSE for details]
# ------------------------------------------------------------------------
import
torch
import
torch.nn
as
nn
import
torch.nn.functional
as
F
from
timm.models.layers
import
DropPath
class
FeatureResizer
(
nn
.
Module
):
"""
This class takes as input a set of embeddings of dimension C1 and outputs a set of
embedding of dimension C2, after a linear transformation, dropout and normalization (LN).
"""
def
__init__
(
self
,
input_feat_size
,
output_feat_size
,
dropout
,
do_ln
=
True
):
super
().
__init__
()
self
.
do_ln
=
do_ln
# Object feature encoding
self
.
fc
=
nn
.
Linear
(
input_feat_size
,
output_feat_size
,
bias
=
True
)
self
.
layer_norm
=
nn
.
LayerNorm
(
output_feat_size
,
eps
=
1e-12
)
self
.
dropout
=
nn
.
Dropout
(
dropout
)
def
forward
(
self
,
encoder_features
):
x
=
self
.
fc
(
encoder_features
)
if
self
.
do_ln
:
x
=
self
.
layer_norm
(
x
)
output
=
self
.
dropout
(
x
)
return
output
def
l1norm
(
X
,
dim
,
eps
=
1e-8
):
"""L1-normalize columns of X"""
norm
=
torch
.
abs
(
X
).
sum
(
dim
=
dim
,
keepdim
=
True
)
+
eps
X
=
torch
.
div
(
X
,
norm
)
return
X
def
l2norm
(
X
,
dim
,
eps
=
1e-8
):
"""L2-normalize columns of X"""
norm
=
torch
.
pow
(
X
,
2
).
sum
(
dim
=
dim
,
keepdim
=
True
).
sqrt
()
+
eps
X
=
torch
.
div
(
X
,
norm
)
return
X
def
func_attention
(
query
,
context
,
smooth
=
1
,
raw_feature_norm
=
"softmax"
,
eps
=
1e-8
):
"""
query: (n_context, queryL, d)
context: (n_context, sourceL, d)
"""
batch_size_q
,
queryL
=
query
.
size
(
0
),
query
.
size
(
1
)
batch_size
,
sourceL
=
context
.
size
(
0
),
context
.
size
(
1
)
# Get attention
# --> (batch, d, queryL)
queryT
=
torch
.
transpose
(
query
,
1
,
2
)
# (batch, sourceL, d)(batch, d, queryL)
# --> (batch, sourceL, queryL)
attn
=
torch
.
bmm
(
context
,
queryT
)
if
raw_feature_norm
==
"softmax"
:
# --> (batch*sourceL, queryL)
attn
=
attn
.
view
(
batch_size
*
sourceL
,
queryL
)
attn
=
nn
.
Softmax
()(
attn
)
# --> (batch, sourceL, queryL)
attn
=
attn
.
view
(
batch_size
,
sourceL
,
queryL
)
elif
raw_feature_norm
==
"l2norm"
:
attn
=
l2norm
(
attn
,
2
)
elif
raw_feature_norm
==
"clipped_l2norm"
:
attn
=
nn
.
LeakyReLU
(
0.1
)(
attn
)
attn
=
l2norm
(
attn
,
2
)
else
:
raise
ValueError
(
"unknown first norm type:"
,
raw_feature_norm
)
# --> (batch, queryL, sourceL)
attn
=
torch
.
transpose
(
attn
,
1
,
2
).
contiguous
()
# --> (batch*queryL, sourceL)
attn
=
attn
.
view
(
batch_size
*
queryL
,
sourceL
)
attn
=
nn
.
Softmax
()(
attn
*
smooth
)
# --> (batch, queryL, sourceL)
attn
=
attn
.
view
(
batch_size
,
queryL
,
sourceL
)
# --> (batch, sourceL, queryL)
attnT
=
torch
.
transpose
(
attn
,
1
,
2
).
contiguous
()
# --> (batch, d, sourceL)
contextT
=
torch
.
transpose
(
context
,
1
,
2
)
# (batch x d x sourceL)(batch x sourceL x queryL)
# --> (batch, d, queryL)
weightedContext
=
torch
.
bmm
(
contextT
,
attnT
)
# --> (batch, queryL, d)
weightedContext
=
torch
.
transpose
(
weightedContext
,
1
,
2
)
return
weightedContext
,
attnT
class
BiMultiHeadAttention
(
nn
.
Module
):
def
__init__
(
self
,
v_dim
,
l_dim
,
embed_dim
,
num_heads
,
dropout
=
0.1
,
cfg
=
None
):
super
(
BiMultiHeadAttention
,
self
).
__init__
()
self
.
embed_dim
=
embed_dim
self
.
num_heads
=
num_heads
self
.
head_dim
=
embed_dim
//
num_heads
self
.
v_dim
=
v_dim
self
.
l_dim
=
l_dim
assert
(
self
.
head_dim
*
self
.
num_heads
==
self
.
embed_dim
),
f
"embed_dim must be divisible by num_heads (got `embed_dim`:
{
self
.
embed_dim
}
and `num_heads`:
{
self
.
num_heads
}
)."
self
.
scale
=
self
.
head_dim
**
(
-
0.5
)
self
.
dropout
=
dropout
self
.
v_proj
=
nn
.
Linear
(
self
.
v_dim
,
self
.
embed_dim
)
self
.
l_proj
=
nn
.
Linear
(
self
.
l_dim
,
self
.
embed_dim
)
self
.
values_v_proj
=
nn
.
Linear
(
self
.
v_dim
,
self
.
embed_dim
)
self
.
values_l_proj
=
nn
.
Linear
(
self
.
l_dim
,
self
.
embed_dim
)
self
.
out_v_proj
=
nn
.
Linear
(
self
.
embed_dim
,
self
.
v_dim
)
self
.
out_l_proj
=
nn
.
Linear
(
self
.
embed_dim
,
self
.
l_dim
)
self
.
stable_softmax_2d
=
True
self
.
clamp_min_for_underflow
=
True
self
.
clamp_max_for_overflow
=
True
self
.
_reset_parameters
()
def
_shape
(
self
,
tensor
:
torch
.
Tensor
,
seq_len
:
int
,
bsz
:
int
):
return
tensor
.
view
(
bsz
,
seq_len
,
self
.
num_heads
,
self
.
head_dim
).
transpose
(
1
,
2
).
contiguous
()
def
_reset_parameters
(
self
):
nn
.
init
.
xavier_uniform_
(
self
.
v_proj
.
weight
)
self
.
v_proj
.
bias
.
data
.
fill_
(
0
)
nn
.
init
.
xavier_uniform_
(
self
.
l_proj
.
weight
)
self
.
l_proj
.
bias
.
data
.
fill_
(
0
)
nn
.
init
.
xavier_uniform_
(
self
.
values_v_proj
.
weight
)
self
.
values_v_proj
.
bias
.
data
.
fill_
(
0
)
nn
.
init
.
xavier_uniform_
(
self
.
values_l_proj
.
weight
)
self
.
values_l_proj
.
bias
.
data
.
fill_
(
0
)
nn
.
init
.
xavier_uniform_
(
self
.
out_v_proj
.
weight
)
self
.
out_v_proj
.
bias
.
data
.
fill_
(
0
)
nn
.
init
.
xavier_uniform_
(
self
.
out_l_proj
.
weight
)
self
.
out_l_proj
.
bias
.
data
.
fill_
(
0
)
def
forward
(
self
,
v
,
l
,
attention_mask_v
=
None
,
attention_mask_l
=
None
):
"""_summary_
Args:
v (_type_): bs, n_img, dim
l (_type_): bs, n_text, dim
attention_mask_v (_type_, optional): _description_. bs, n_img
attention_mask_l (_type_, optional): _description_. bs, n_text
Returns:
_type_: _description_
"""
# if os.environ.get('IPDB_SHILONG_DEBUG', None) == 'INFO':
# import ipdb; ipdb.set_trace()
bsz
,
tgt_len
,
_
=
v
.
size
()
query_states
=
self
.
v_proj
(
v
)
*
self
.
scale
key_states
=
self
.
_shape
(
self
.
l_proj
(
l
),
-
1
,
bsz
)
value_v_states
=
self
.
_shape
(
self
.
values_v_proj
(
v
),
-
1
,
bsz
)
value_l_states
=
self
.
_shape
(
self
.
values_l_proj
(
l
),
-
1
,
bsz
)
proj_shape
=
(
bsz
*
self
.
num_heads
,
-
1
,
self
.
head_dim
)
query_states
=
self
.
_shape
(
query_states
,
tgt_len
,
bsz
).
view
(
*
proj_shape
)
key_states
=
key_states
.
view
(
*
proj_shape
)
value_v_states
=
value_v_states
.
view
(
*
proj_shape
)
value_l_states
=
value_l_states
.
view
(
*
proj_shape
)
src_len
=
key_states
.
size
(
1
)
attn_weights
=
torch
.
bmm
(
query_states
,
key_states
.
transpose
(
1
,
2
))
# bs*nhead, nimg, ntxt
if
attn_weights
.
size
()
!=
(
bsz
*
self
.
num_heads
,
tgt_len
,
src_len
):
raise
ValueError
(
f
"Attention weights should be of size
{
(
bsz
*
self
.
num_heads
,
tgt_len
,
src_len
)
}
, but is
{
attn_weights
.
size
()
}
"
)
if
self
.
stable_softmax_2d
:
attn_weights
=
attn_weights
-
attn_weights
.
max
()
if
self
.
clamp_min_for_underflow
:
attn_weights
=
torch
.
clamp
(
attn_weights
,
min
=-
50000
)
# Do not increase -50000, data type half has quite limited range
if
self
.
clamp_max_for_overflow
:
attn_weights
=
torch
.
clamp
(
attn_weights
,
max
=
50000
)
# Do not increase 50000, data type half has quite limited range
attn_weights_T
=
attn_weights
.
transpose
(
1
,
2
)
attn_weights_l
=
attn_weights_T
-
torch
.
max
(
attn_weights_T
,
dim
=-
1
,
keepdim
=
True
)[
0
]
if
self
.
clamp_min_for_underflow
:
attn_weights_l
=
torch
.
clamp
(
attn_weights_l
,
min
=-
50000
)
# Do not increase -50000, data type half has quite limited range
if
self
.
clamp_max_for_overflow
:
attn_weights_l
=
torch
.
clamp
(
attn_weights_l
,
max
=
50000
)
# Do not increase 50000, data type half has quite limited range
# mask vison for language
if
attention_mask_v
is
not
None
:
attention_mask_v
=
(
attention_mask_v
[:,
None
,
None
,
:].
repeat
(
1
,
self
.
num_heads
,
1
,
1
).
flatten
(
0
,
1
)
)
attn_weights_l
.
masked_fill_
(
attention_mask_v
,
float
(
"-inf"
))
attn_weights_l
=
attn_weights_l
.
softmax
(
dim
=-
1
)
# mask language for vision
if
attention_mask_l
is
not
None
:
attention_mask_l
=
(
attention_mask_l
[:,
None
,
None
,
:].
repeat
(
1
,
self
.
num_heads
,
1
,
1
).
flatten
(
0
,
1
)
)
attn_weights
.
masked_fill_
(
attention_mask_l
,
float
(
"-inf"
))
attn_weights_v
=
attn_weights
.
softmax
(
dim
=-
1
)
attn_probs_v
=
F
.
dropout
(
attn_weights_v
,
p
=
self
.
dropout
,
training
=
self
.
training
)
attn_probs_l
=
F
.
dropout
(
attn_weights_l
,
p
=
self
.
dropout
,
training
=
self
.
training
)
attn_output_v
=
torch
.
bmm
(
attn_probs_v
,
value_l_states
)
attn_output_l
=
torch
.
bmm
(
attn_probs_l
,
value_v_states
)
if
attn_output_v
.
size
()
!=
(
bsz
*
self
.
num_heads
,
tgt_len
,
self
.
head_dim
):
raise
ValueError
(
f
"`attn_output_v` should be of size
{
(
bsz
,
self
.
num_heads
,
tgt_len
,
self
.
head_dim
)
}
, but is
{
attn_output_v
.
size
()
}
"
)
if
attn_output_l
.
size
()
!=
(
bsz
*
self
.
num_heads
,
src_len
,
self
.
head_dim
):
raise
ValueError
(
f
"`attn_output_l` should be of size
{
(
bsz
,
self
.
num_heads
,
src_len
,
self
.
head_dim
)
}
, but is
{
attn_output_l
.
size
()
}
"
)
attn_output_v
=
attn_output_v
.
view
(
bsz
,
self
.
num_heads
,
tgt_len
,
self
.
head_dim
)
attn_output_v
=
attn_output_v
.
transpose
(
1
,
2
)
attn_output_v
=
attn_output_v
.
reshape
(
bsz
,
tgt_len
,
self
.
embed_dim
)
attn_output_l
=
attn_output_l
.
view
(
bsz
,
self
.
num_heads
,
src_len
,
self
.
head_dim
)
attn_output_l
=
attn_output_l
.
transpose
(
1
,
2
)
attn_output_l
=
attn_output_l
.
reshape
(
bsz
,
src_len
,
self
.
embed_dim
)
attn_output_v
=
self
.
out_v_proj
(
attn_output_v
)
attn_output_l
=
self
.
out_l_proj
(
attn_output_l
)
return
attn_output_v
,
attn_output_l
# Bi-Direction MHA (text->image, image->text)
class
BiAttentionBlock
(
nn
.
Module
):
def
__init__
(
self
,
v_dim
,
l_dim
,
embed_dim
,
num_heads
,
dropout
=
0.1
,
drop_path
=
0.0
,
init_values
=
1e-4
,
cfg
=
None
,
):
"""
Inputs:
embed_dim - Dimensionality of input and attention feature vectors
hidden_dim - Dimensionality of hidden layer in feed-forward network
(usually 2-4x larger than embed_dim)
num_heads - Number of heads to use in the Multi-Head Attention block
dropout - Amount of dropout to apply in the feed-forward network
"""
super
(
BiAttentionBlock
,
self
).
__init__
()
# pre layer norm
self
.
layer_norm_v
=
nn
.
LayerNorm
(
v_dim
)
self
.
layer_norm_l
=
nn
.
LayerNorm
(
l_dim
)
self
.
attn
=
BiMultiHeadAttention
(
v_dim
=
v_dim
,
l_dim
=
l_dim
,
embed_dim
=
embed_dim
,
num_heads
=
num_heads
,
dropout
=
dropout
)
# add layer scale for training stability
self
.
drop_path
=
DropPath
(
drop_path
)
if
drop_path
>
0.0
else
nn
.
Identity
()
self
.
gamma_v
=
nn
.
Parameter
(
init_values
*
torch
.
ones
((
v_dim
)),
requires_grad
=
True
)
self
.
gamma_l
=
nn
.
Parameter
(
init_values
*
torch
.
ones
((
l_dim
)),
requires_grad
=
True
)
def
forward
(
self
,
v
,
l
,
attention_mask_v
=
None
,
attention_mask_l
=
None
):
v
=
self
.
layer_norm_v
(
v
)
l
=
self
.
layer_norm_l
(
l
)
delta_v
,
delta_l
=
self
.
attn
(
v
,
l
,
attention_mask_v
=
attention_mask_v
,
attention_mask_l
=
attention_mask_l
)
# v, l = v + delta_v, l + delta_l
v
=
v
+
self
.
drop_path
(
self
.
gamma_v
*
delta_v
)
l
=
l
+
self
.
drop_path
(
self
.
gamma_l
*
delta_l
)
return
v
,
l
# def forward(self, v:List[torch.Tensor], l, attention_mask_v=None, attention_mask_l=None)
groundingdino/models/GroundingDINO/groundingdino.py
0 → 100644
View file @
34e4011b
# ------------------------------------------------------------------------
# Grounding DINO
# url: https://github.com/IDEA-Research/GroundingDINO
# Copyright (c) 2023 IDEA. All Rights Reserved.
# Licensed under the Apache License, Version 2.0 [see LICENSE for details]
# ------------------------------------------------------------------------
# Conditional DETR model and criterion classes.
# Copyright (c) 2021 Microsoft. All Rights Reserved.
# Licensed under the Apache License, Version 2.0 [see LICENSE for details]
# ------------------------------------------------------------------------
# Modified from DETR (https://github.com/facebookresearch/detr)
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
# ------------------------------------------------------------------------
# Modified from Deformable DETR (https://github.com/fundamentalvision/Deformable-DETR)
# Copyright (c) 2020 SenseTime. All Rights Reserved.
# ------------------------------------------------------------------------
import
copy
from
typing
import
List
,
Optional
import
torch
import
torch.nn.functional
as
F
from
torch
import
nn
from
torchvision.ops.boxes
import
nms
from
transformers
import
AutoTokenizer
,
BertModel
,
BertTokenizer
,
RobertaModel
,
RobertaTokenizerFast
from
groundingdino.util
import
box_ops
,
get_tokenlizer
from
groundingdino.util.misc
import
(
NestedTensor
,
accuracy
,
get_world_size
,
interpolate
,
inverse_sigmoid
,
is_dist_avail_and_initialized
,
nested_tensor_from_tensor_list
,
)
from
groundingdino.util.utils
import
get_phrases_from_posmap
from
groundingdino.util.visualizer
import
COCOVisualizer
from
groundingdino.util.vl_utils
import
create_positive_map_from_span
from
..registry
import
MODULE_BUILD_FUNCS
from
.backbone
import
build_backbone
from
.bertwarper
import
(
BertModelWarper
,
generate_masks_with_special_tokens
,
generate_masks_with_special_tokens_and_transfer_map
,
)
from
.transformer
import
build_transformer
from
.utils
import
MLP
,
ContrastiveEmbed
,
sigmoid_focal_loss
class
GroundingDINO
(
nn
.
Module
):
"""This is the Cross-Attention Detector module that performs object detection"""
def
__init__
(
self
,
backbone
,
transformer
,
num_queries
,
aux_loss
=
False
,
iter_update
=
False
,
query_dim
=
2
,
num_feature_levels
=
1
,
nheads
=
8
,
# two stage
two_stage_type
=
"no"
,
# ['no', 'standard']
dec_pred_bbox_embed_share
=
True
,
two_stage_class_embed_share
=
True
,
two_stage_bbox_embed_share
=
True
,
num_patterns
=
0
,
dn_number
=
100
,
dn_box_noise_scale
=
0.4
,
dn_label_noise_ratio
=
0.5
,
dn_labelbook_size
=
100
,
text_encoder_type
=
"bert-base-uncased"
,
sub_sentence_present
=
True
,
max_text_len
=
256
,
):
"""Initializes the model.
Parameters:
backbone: torch module of the backbone to be used. See backbone.py
transformer: torch module of the transformer architecture. See transformer.py
num_queries: number of object queries, ie detection slot. This is the maximal number of objects
Conditional DETR can detect in a single image. For COCO, we recommend 100 queries.
aux_loss: True if auxiliary decoding losses (loss at each decoder layer) are to be used.
"""
super
().
__init__
()
self
.
num_queries
=
num_queries
self
.
transformer
=
transformer
self
.
hidden_dim
=
hidden_dim
=
transformer
.
d_model
self
.
num_feature_levels
=
num_feature_levels
self
.
nheads
=
nheads
self
.
max_text_len
=
256
self
.
sub_sentence_present
=
sub_sentence_present
# setting query dim
self
.
query_dim
=
query_dim
assert
query_dim
==
4
# for dn training
self
.
num_patterns
=
num_patterns
self
.
dn_number
=
dn_number
self
.
dn_box_noise_scale
=
dn_box_noise_scale
self
.
dn_label_noise_ratio
=
dn_label_noise_ratio
self
.
dn_labelbook_size
=
dn_labelbook_size
# bert
self
.
tokenizer
=
get_tokenlizer
.
get_tokenlizer
(
text_encoder_type
)
self
.
bert
=
get_tokenlizer
.
get_pretrained_language_model
(
text_encoder_type
)
self
.
bert
.
pooler
.
dense
.
weight
.
requires_grad_
(
False
)
self
.
bert
.
pooler
.
dense
.
bias
.
requires_grad_
(
False
)
self
.
bert
=
BertModelWarper
(
bert_model
=
self
.
bert
)
self
.
feat_map
=
nn
.
Linear
(
self
.
bert
.
config
.
hidden_size
,
self
.
hidden_dim
,
bias
=
True
)
nn
.
init
.
constant_
(
self
.
feat_map
.
bias
.
data
,
0
)
nn
.
init
.
xavier_uniform_
(
self
.
feat_map
.
weight
.
data
)
# freeze
# special tokens
self
.
specical_tokens
=
self
.
tokenizer
.
convert_tokens_to_ids
([
"[CLS]"
,
"[SEP]"
,
"."
,
"?"
])
# prepare input projection layers
if
num_feature_levels
>
1
:
num_backbone_outs
=
len
(
backbone
.
num_channels
)
input_proj_list
=
[]
for
_
in
range
(
num_backbone_outs
):
in_channels
=
backbone
.
num_channels
[
_
]
input_proj_list
.
append
(
nn
.
Sequential
(
nn
.
Conv2d
(
in_channels
,
hidden_dim
,
kernel_size
=
1
),
nn
.
GroupNorm
(
32
,
hidden_dim
),
)
)
for
_
in
range
(
num_feature_levels
-
num_backbone_outs
):
input_proj_list
.
append
(
nn
.
Sequential
(
nn
.
Conv2d
(
in_channels
,
hidden_dim
,
kernel_size
=
3
,
stride
=
2
,
padding
=
1
),
nn
.
GroupNorm
(
32
,
hidden_dim
),
)
)
in_channels
=
hidden_dim
self
.
input_proj
=
nn
.
ModuleList
(
input_proj_list
)
else
:
assert
two_stage_type
==
"no"
,
"two_stage_type should be no if num_feature_levels=1 !!!"
self
.
input_proj
=
nn
.
ModuleList
(
[
nn
.
Sequential
(
nn
.
Conv2d
(
backbone
.
num_channels
[
-
1
],
hidden_dim
,
kernel_size
=
1
),
nn
.
GroupNorm
(
32
,
hidden_dim
),
)
]
)
self
.
backbone
=
backbone
self
.
aux_loss
=
aux_loss
self
.
box_pred_damping
=
box_pred_damping
=
None
self
.
iter_update
=
iter_update
assert
iter_update
,
"Why not iter_update?"
# prepare pred layers
self
.
dec_pred_bbox_embed_share
=
dec_pred_bbox_embed_share
# prepare class & box embed
_class_embed
=
ContrastiveEmbed
()
_bbox_embed
=
MLP
(
hidden_dim
,
hidden_dim
,
4
,
3
)
nn
.
init
.
constant_
(
_bbox_embed
.
layers
[
-
1
].
weight
.
data
,
0
)
nn
.
init
.
constant_
(
_bbox_embed
.
layers
[
-
1
].
bias
.
data
,
0
)
if
dec_pred_bbox_embed_share
:
box_embed_layerlist
=
[
_bbox_embed
for
i
in
range
(
transformer
.
num_decoder_layers
)]
else
:
box_embed_layerlist
=
[
copy
.
deepcopy
(
_bbox_embed
)
for
i
in
range
(
transformer
.
num_decoder_layers
)
]
class_embed_layerlist
=
[
_class_embed
for
i
in
range
(
transformer
.
num_decoder_layers
)]
self
.
bbox_embed
=
nn
.
ModuleList
(
box_embed_layerlist
)
self
.
class_embed
=
nn
.
ModuleList
(
class_embed_layerlist
)
self
.
transformer
.
decoder
.
bbox_embed
=
self
.
bbox_embed
self
.
transformer
.
decoder
.
class_embed
=
self
.
class_embed
# two stage
self
.
two_stage_type
=
two_stage_type
assert
two_stage_type
in
[
"no"
,
"standard"
],
"unknown param {} of two_stage_type"
.
format
(
two_stage_type
)
if
two_stage_type
!=
"no"
:
if
two_stage_bbox_embed_share
:
assert
dec_pred_bbox_embed_share
self
.
transformer
.
enc_out_bbox_embed
=
_bbox_embed
else
:
self
.
transformer
.
enc_out_bbox_embed
=
copy
.
deepcopy
(
_bbox_embed
)
if
two_stage_class_embed_share
:
assert
dec_pred_bbox_embed_share
self
.
transformer
.
enc_out_class_embed
=
_class_embed
else
:
self
.
transformer
.
enc_out_class_embed
=
copy
.
deepcopy
(
_class_embed
)
self
.
refpoint_embed
=
None
self
.
_reset_parameters
()
def
_reset_parameters
(
self
):
# init input_proj
for
proj
in
self
.
input_proj
:
nn
.
init
.
xavier_uniform_
(
proj
[
0
].
weight
,
gain
=
1
)
nn
.
init
.
constant_
(
proj
[
0
].
bias
,
0
)
def
set_image_tensor
(
self
,
samples
:
NestedTensor
):
if
isinstance
(
samples
,
(
list
,
torch
.
Tensor
)):
samples
=
nested_tensor_from_tensor_list
(
samples
)
self
.
features
,
self
.
poss
=
self
.
backbone
(
samples
)
def
unset_image_tensor
(
self
):
if
hasattr
(
self
,
'features'
):
del
self
.
features
if
hasattr
(
self
,
'poss'
):
del
self
.
poss
def
set_image_features
(
self
,
features
,
poss
):
self
.
features
=
features
self
.
poss
=
poss
def
init_ref_points
(
self
,
use_num_queries
):
self
.
refpoint_embed
=
nn
.
Embedding
(
use_num_queries
,
self
.
query_dim
)
# def forward(self, samples: NestedTensor, targets: List = None, **kw):
def
forward
(
self
,
samples
:
NestedTensor
,
input_ids
:
Optional
[
torch
.
Tensor
]
=
None
,
attention_mask
:
Optional
[
torch
.
Tensor
]
=
None
,
position_ids
:
Optional
[
torch
.
Tensor
]
=
None
,
token_type_ids
:
Optional
[
torch
.
Tensor
]
=
None
,
text_self_attention_masks
:
Optional
[
torch
.
Tensor
]
=
None
):
"""The forward expects a NestedTensor, which consists of:
- samples.tensor: batched images, of shape [batch_size x 3 x H x W]
- samples.mask: a binary mask of shape [batch_size x H x W], containing 1 on padded pixels
It returns a dict with the following elements:
- "pred_logits": the classification logits (including no-object) for all queries.
Shape= [batch_size x num_queries x num_classes]
- "pred_boxes": The normalized boxes coordinates for all queries, represented as
(center_x, center_y, width, height). These values are normalized in [0, 1],
relative to the size of each individual image (disregarding possible padding).
See PostProcess for information on how to retrieve the unnormalized bounding box.
- "aux_outputs": Optional, only returned when auxilary losses are activated. It is a list of
dictionnaries containing the two above keys for each decoder layer.
"""
# if targets is None:
# captions = kw["captions"]
# else:
# captions = [t["caption"] for t in targets]
# # encoder texts
# tokenized = self.tokenizer(captions, padding="longest", return_tensors="pt").to(
# samples.device
# )
# (
# text_self_attention_masks,
# position_ids,
# cate_to_token_mask_list,
# ) = generate_masks_with_special_tokens_and_transfer_map(
# tokenized, self.specical_tokens, self.tokenizer
# )
# if text_self_attention_masks.shape[1] > self.max_text_len:
# text_self_attention_masks = text_self_attention_masks[
# :, : self.max_text_len, : self.max_text_len
# ]
# position_ids = position_ids[:, : self.max_text_len]
# tokenized["input_ids"] = tokenized["input_ids"][:, : self.max_text_len]
# tokenized["attention_mask"] = tokenized["attention_mask"][:, : self.max_text_len]
# tokenized["token_type_ids"] = tokenized["token_type_ids"][:, : self.max_text_len]
# # extract text embeddings
# if self.sub_sentence_present:
# tokenized_for_encoder = {k: v for k, v in tokenized.items() if k != "attention_mask"}
# tokenized_for_encoder["attention_mask"] = text_self_attention_masks
# tokenized_for_encoder["position_ids"] = position_ids
# else:
# # import ipdb; ipdb.set_trace()
# tokenized_for_encoder = tokenized
# ---------------------------输入不同--------------------------
tokenized_for_encoder
=
{}
tokenized_for_encoder
[
"input_ids"
]
=
input_ids
tokenized_for_encoder
[
"attention_mask"
]
=
text_self_attention_masks
tokenized_for_encoder
[
"position_ids"
]
=
position_ids
tokenized_for_encoder
[
"token_type_ids"
]
=
token_type_ids
# --------------------------输入不同结束------------------------
bert_output
=
self
.
bert
(
**
tokenized_for_encoder
)
# bs, 195, 768
encoded_text
=
self
.
feat_map
(
bert_output
[
"last_hidden_state"
])
# bs, 195, d_model
#------------------------------2---------------------------------
# text_token_mask = tokenized.attention_mask.bool() # bs, 195
#-----------------------------------------------------------------
text_token_mask
=
attention_mask
.
bool
()
# bs, 195
# text_token_mask: True for nomask, False for mask
# text_self_attention_masks: True for nomask, False for mask
if
encoded_text
.
shape
[
1
]
>
self
.
max_text_len
:
encoded_text
=
encoded_text
[:,
:
self
.
max_text_len
,
:]
text_token_mask
=
text_token_mask
[:,
:
self
.
max_text_len
]
position_ids
=
position_ids
[:,
:
self
.
max_text_len
]
text_self_attention_masks
=
text_self_attention_masks
[
:,
:
self
.
max_text_len
,
:
self
.
max_text_len
]
text_dict
=
{
"encoded_text"
:
encoded_text
,
# bs, 195, d_model
"text_token_mask"
:
text_token_mask
,
# bs, 195
"position_ids"
:
position_ids
,
# bs, 195
"text_self_attention_masks"
:
text_self_attention_masks
,
# bs, 195,195
}
# import ipdb; ipdb.set_trace()
if
isinstance
(
samples
,
(
list
,
torch
.
Tensor
)):
samples
=
nested_tensor_from_tensor_list
(
samples
)
if
not
hasattr
(
self
,
'features'
)
or
not
hasattr
(
self
,
'poss'
):
self
.
set_image_tensor
(
samples
)
srcs
=
[]
masks
=
[]
for
l
,
feat
in
enumerate
(
self
.
features
):
src
,
mask
=
feat
.
decompose
()
srcs
.
append
(
self
.
input_proj
[
l
](
src
))
masks
.
append
(
mask
)
assert
mask
is
not
None
if
self
.
num_feature_levels
>
len
(
srcs
):
_len_srcs
=
len
(
srcs
)
for
l
in
range
(
_len_srcs
,
self
.
num_feature_levels
):
if
l
==
_len_srcs
:
src
=
self
.
input_proj
[
l
](
self
.
features
[
-
1
].
tensors
)
else
:
src
=
self
.
input_proj
[
l
](
srcs
[
-
1
])
m
=
samples
.
mask
mask
=
F
.
interpolate
(
m
[
None
].
float
(),
size
=
src
.
shape
[
-
2
:]).
to
(
torch
.
bool
)[
0
]
pos_l
=
self
.
backbone
[
1
](
NestedTensor
(
src
,
mask
)).
to
(
src
.
dtype
)
srcs
.
append
(
src
)
masks
.
append
(
mask
)
self
.
poss
.
append
(
pos_l
)
input_query_bbox
=
input_query_label
=
attn_mask
=
dn_meta
=
None
hs
,
reference
,
hs_enc
,
ref_enc
,
init_box_proposal
=
self
.
transformer
(
srcs
,
masks
,
input_query_bbox
,
self
.
poss
,
input_query_label
,
attn_mask
,
text_dict
)
# deformable-detr-like anchor update
outputs_coord_list
=
[]
for
dec_lid
,
(
layer_ref_sig
,
layer_bbox_embed
,
layer_hs
)
in
enumerate
(
zip
(
reference
[:
-
1
],
self
.
bbox_embed
,
hs
)
):
layer_delta_unsig
=
layer_bbox_embed
(
layer_hs
)
layer_outputs_unsig
=
layer_delta_unsig
+
inverse_sigmoid
(
layer_ref_sig
)
layer_outputs_unsig
=
layer_outputs_unsig
.
sigmoid
()
outputs_coord_list
.
append
(
layer_outputs_unsig
)
outputs_coord_list
=
torch
.
stack
(
outputs_coord_list
)
# output
outputs_class
=
torch
.
stack
(
[
layer_cls_embed
(
layer_hs
,
text_dict
)
for
layer_cls_embed
,
layer_hs
in
zip
(
self
.
class_embed
,
hs
)
]
)
out
=
{
"pred_logits"
:
outputs_class
[
-
1
],
"pred_boxes"
:
outputs_coord_list
[
-
1
]}
# # for intermediate outputs
# if self.aux_loss:
# out['aux_outputs'] = self._set_aux_loss(outputs_class, outputs_coord_list)
# # for encoder output
# if hs_enc is not None:
# # prepare intermediate outputs
# interm_coord = ref_enc[-1]
# interm_class = self.transformer.enc_out_class_embed(hs_enc[-1], text_dict)
# out['interm_outputs'] = {'pred_logits': interm_class, 'pred_boxes': interm_coord}
# out['interm_outputs_for_matching_pre'] = {'pred_logits': interm_class, 'pred_boxes': init_box_proposal}
#-------------------------2---------------------------------
# unset_image_tensor = kw.get('unset_image_tensor', True)
# if unset_image_tensor:
# self.unset_image_tensor() ## If necessary
#-----------------------------------------------------------
return
out
@
torch
.
jit
.
unused
def
_set_aux_loss
(
self
,
outputs_class
,
outputs_coord
):
# this is a workaround to make torchscript happy, as torchscript
# doesn't support dictionary with non-homogeneous values, such
# as a dict having both a Tensor and a list.
return
[
{
"pred_logits"
:
a
,
"pred_boxes"
:
b
}
for
a
,
b
in
zip
(
outputs_class
[:
-
1
],
outputs_coord
[:
-
1
])
]
@
MODULE_BUILD_FUNCS
.
registe_with_name
(
module_name
=
"groundingdino"
)
def
build_groundingdino
(
args
):
backbone
=
build_backbone
(
args
)
transformer
=
build_transformer
(
args
)
dn_labelbook_size
=
args
.
dn_labelbook_size
dec_pred_bbox_embed_share
=
args
.
dec_pred_bbox_embed_share
sub_sentence_present
=
args
.
sub_sentence_present
model
=
GroundingDINO
(
backbone
,
transformer
,
num_queries
=
args
.
num_queries
,
aux_loss
=
True
,
iter_update
=
True
,
query_dim
=
4
,
num_feature_levels
=
args
.
num_feature_levels
,
nheads
=
args
.
nheads
,
dec_pred_bbox_embed_share
=
dec_pred_bbox_embed_share
,
two_stage_type
=
args
.
two_stage_type
,
two_stage_bbox_embed_share
=
args
.
two_stage_bbox_embed_share
,
two_stage_class_embed_share
=
args
.
two_stage_class_embed_share
,
num_patterns
=
args
.
num_patterns
,
dn_number
=
0
,
dn_box_noise_scale
=
args
.
dn_box_noise_scale
,
dn_label_noise_ratio
=
args
.
dn_label_noise_ratio
,
dn_labelbook_size
=
dn_labelbook_size
,
text_encoder_type
=
args
.
text_encoder_type
,
sub_sentence_present
=
sub_sentence_present
,
max_text_len
=
args
.
max_text_len
,
)
return
model
groundingdino/models/GroundingDINO/groundingdino_onnx.py
0 → 100644
View file @
34e4011b
# ------------------------------------------------------------------------
# Grounding DINO
# url: https://github.com/IDEA-Research/GroundingDINO
# Copyright (c) 2023 IDEA. All Rights Reserved.
# Licensed under the Apache License, Version 2.0 [see LICENSE for details]
# ------------------------------------------------------------------------
# Conditional DETR model and criterion classes.
# Copyright (c) 2021 Microsoft. All Rights Reserved.
# Licensed under the Apache License, Version 2.0 [see LICENSE for details]
# ------------------------------------------------------------------------
# Modified from DETR (https://github.com/facebookresearch/detr)
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
# ------------------------------------------------------------------------
# Modified from Deformable DETR (https://github.com/fundamentalvision/Deformable-DETR)
# Copyright (c) 2020 SenseTime. All Rights Reserved.
# ------------------------------------------------------------------------
import
copy
from
typing
import
List
,
Optional
import
torch
import
torch.nn.functional
as
F
from
torch
import
nn
from
torchvision.ops.boxes
import
nms
from
transformers
import
AutoTokenizer
,
BertModel
,
BertTokenizer
,
RobertaModel
,
RobertaTokenizerFast
from
groundingdino.util
import
box_ops
,
get_tokenlizer
from
groundingdino.util.misc
import
(
NestedTensor
,
accuracy
,
get_world_size
,
interpolate
,
inverse_sigmoid
,
is_dist_avail_and_initialized
,
nested_tensor_from_tensor_list
,
)
from
groundingdino.util.utils
import
get_phrases_from_posmap
from
groundingdino.util.visualizer
import
COCOVisualizer
from
groundingdino.util.vl_utils
import
create_positive_map_from_span
from
..registry
import
MODULE_BUILD_FUNCS
from
.backbone
import
build_backbone
from
.bertwarper
import
(
BertModelWarper
,
generate_masks_with_special_tokens
,
generate_masks_with_special_tokens_and_transfer_map
,
)
from
.transformer
import
build_transformer
from
.utils
import
MLP
,
ContrastiveEmbed
,
sigmoid_focal_loss
class
GroundingDINO
(
nn
.
Module
):
"""This is the Cross-Attention Detector module that performs object detection"""
def
__init__
(
self
,
backbone
,
transformer
,
num_queries
,
aux_loss
=
False
,
iter_update
=
False
,
query_dim
=
2
,
num_feature_levels
=
1
,
nheads
=
8
,
# two stage
two_stage_type
=
"no"
,
# ['no', 'standard']
dec_pred_bbox_embed_share
=
True
,
two_stage_class_embed_share
=
True
,
two_stage_bbox_embed_share
=
True
,
num_patterns
=
0
,
dn_number
=
100
,
dn_box_noise_scale
=
0.4
,
dn_label_noise_ratio
=
0.5
,
dn_labelbook_size
=
100
,
text_encoder_type
=
"bert-base-uncased"
,
sub_sentence_present
=
True
,
max_text_len
=
256
,
):
"""Initializes the model.
Parameters:
backbone: torch module of the backbone to be used. See backbone.py
transformer: torch module of the transformer architecture. See transformer.py
num_queries: number of object queries, ie detection slot. This is the maximal number of objects
Conditional DETR can detect in a single image. For COCO, we recommend 100 queries.
aux_loss: True if auxiliary decoding losses (loss at each decoder layer) are to be used.
"""
super
().
__init__
()
self
.
num_queries
=
num_queries
self
.
transformer
=
transformer
self
.
hidden_dim
=
hidden_dim
=
transformer
.
d_model
self
.
num_feature_levels
=
num_feature_levels
self
.
nheads
=
nheads
self
.
max_text_len
=
256
self
.
sub_sentence_present
=
sub_sentence_present
# setting query dim
self
.
query_dim
=
query_dim
assert
query_dim
==
4
# for dn training
self
.
num_patterns
=
num_patterns
self
.
dn_number
=
dn_number
self
.
dn_box_noise_scale
=
dn_box_noise_scale
self
.
dn_label_noise_ratio
=
dn_label_noise_ratio
self
.
dn_labelbook_size
=
dn_labelbook_size
# bert
self
.
tokenizer
=
get_tokenlizer
.
get_tokenlizer
(
text_encoder_type
)
self
.
bert
=
get_tokenlizer
.
get_pretrained_language_model
(
text_encoder_type
)
self
.
bert
.
pooler
.
dense
.
weight
.
requires_grad_
(
False
)
self
.
bert
.
pooler
.
dense
.
bias
.
requires_grad_
(
False
)
self
.
bert
=
BertModelWarper
(
bert_model
=
self
.
bert
)
self
.
feat_map
=
nn
.
Linear
(
self
.
bert
.
config
.
hidden_size
,
self
.
hidden_dim
,
bias
=
True
)
nn
.
init
.
constant_
(
self
.
feat_map
.
bias
.
data
,
0
)
nn
.
init
.
xavier_uniform_
(
self
.
feat_map
.
weight
.
data
)
# freeze
# special tokens
self
.
specical_tokens
=
self
.
tokenizer
.
convert_tokens_to_ids
([
"[CLS]"
,
"[SEP]"
,
"."
,
"?"
])
# prepare input projection layers
if
num_feature_levels
>
1
:
num_backbone_outs
=
len
(
backbone
.
num_channels
)
input_proj_list
=
[]
for
_
in
range
(
num_backbone_outs
):
in_channels
=
backbone
.
num_channels
[
_
]
input_proj_list
.
append
(
nn
.
Sequential
(
nn
.
Conv2d
(
in_channels
,
hidden_dim
,
kernel_size
=
1
),
nn
.
GroupNorm
(
32
,
hidden_dim
),
)
)
for
_
in
range
(
num_feature_levels
-
num_backbone_outs
):
input_proj_list
.
append
(
nn
.
Sequential
(
nn
.
Conv2d
(
in_channels
,
hidden_dim
,
kernel_size
=
3
,
stride
=
2
,
padding
=
1
),
nn
.
GroupNorm
(
32
,
hidden_dim
),
)
)
in_channels
=
hidden_dim
self
.
input_proj
=
nn
.
ModuleList
(
input_proj_list
)
else
:
assert
two_stage_type
==
"no"
,
"two_stage_type should be no if num_feature_levels=1 !!!"
self
.
input_proj
=
nn
.
ModuleList
(
[
nn
.
Sequential
(
nn
.
Conv2d
(
backbone
.
num_channels
[
-
1
],
hidden_dim
,
kernel_size
=
1
),
nn
.
GroupNorm
(
32
,
hidden_dim
),
)
]
)
self
.
backbone
=
backbone
self
.
aux_loss
=
aux_loss
self
.
box_pred_damping
=
box_pred_damping
=
None
self
.
iter_update
=
iter_update
assert
iter_update
,
"Why not iter_update?"
# prepare pred layers
self
.
dec_pred_bbox_embed_share
=
dec_pred_bbox_embed_share
# prepare class & box embed
_class_embed
=
ContrastiveEmbed
()
_bbox_embed
=
MLP
(
hidden_dim
,
hidden_dim
,
4
,
3
)
nn
.
init
.
constant_
(
_bbox_embed
.
layers
[
-
1
].
weight
.
data
,
0
)
nn
.
init
.
constant_
(
_bbox_embed
.
layers
[
-
1
].
bias
.
data
,
0
)
if
dec_pred_bbox_embed_share
:
box_embed_layerlist
=
[
_bbox_embed
for
i
in
range
(
transformer
.
num_decoder_layers
)]
else
:
box_embed_layerlist
=
[
copy
.
deepcopy
(
_bbox_embed
)
for
i
in
range
(
transformer
.
num_decoder_layers
)
]
class_embed_layerlist
=
[
_class_embed
for
i
in
range
(
transformer
.
num_decoder_layers
)]
self
.
bbox_embed
=
nn
.
ModuleList
(
box_embed_layerlist
)
self
.
class_embed
=
nn
.
ModuleList
(
class_embed_layerlist
)
self
.
transformer
.
decoder
.
bbox_embed
=
self
.
bbox_embed
self
.
transformer
.
decoder
.
class_embed
=
self
.
class_embed
# two stage
self
.
two_stage_type
=
two_stage_type
assert
two_stage_type
in
[
"no"
,
"standard"
],
"unknown param {} of two_stage_type"
.
format
(
two_stage_type
)
if
two_stage_type
!=
"no"
:
if
two_stage_bbox_embed_share
:
assert
dec_pred_bbox_embed_share
self
.
transformer
.
enc_out_bbox_embed
=
_bbox_embed
else
:
self
.
transformer
.
enc_out_bbox_embed
=
copy
.
deepcopy
(
_bbox_embed
)
if
two_stage_class_embed_share
:
assert
dec_pred_bbox_embed_share
self
.
transformer
.
enc_out_class_embed
=
_class_embed
else
:
self
.
transformer
.
enc_out_class_embed
=
copy
.
deepcopy
(
_class_embed
)
self
.
refpoint_embed
=
None
self
.
_reset_parameters
()
def
_reset_parameters
(
self
):
# init input_proj
for
proj
in
self
.
input_proj
:
nn
.
init
.
xavier_uniform_
(
proj
[
0
].
weight
,
gain
=
1
)
nn
.
init
.
constant_
(
proj
[
0
].
bias
,
0
)
def
set_image_tensor
(
self
,
samples
:
NestedTensor
):
if
isinstance
(
samples
,
(
list
,
torch
.
Tensor
)):
samples
=
nested_tensor_from_tensor_list
(
samples
)
self
.
features
,
self
.
poss
=
self
.
backbone
(
samples
)
def
unset_image_tensor
(
self
):
if
hasattr
(
self
,
'features'
):
del
self
.
features
if
hasattr
(
self
,
'poss'
):
del
self
.
poss
def
set_image_features
(
self
,
features
,
poss
):
self
.
features
=
features
self
.
poss
=
poss
def
init_ref_points
(
self
,
use_num_queries
):
self
.
refpoint_embed
=
nn
.
Embedding
(
use_num_queries
,
self
.
query_dim
)
# def forward(self, samples: NestedTensor, targets: List = None, **kw):
def
forward
(
self
,
samples
:
NestedTensor
,
input_ids
:
Optional
[
torch
.
Tensor
]
=
None
,
attention_mask
:
Optional
[
torch
.
Tensor
]
=
None
,
position_ids
:
Optional
[
torch
.
Tensor
]
=
None
,
token_type_ids
:
Optional
[
torch
.
Tensor
]
=
None
,
text_self_attention_masks
:
Optional
[
torch
.
Tensor
]
=
None
):
"""The forward expects a NestedTensor, which consists of:
- samples.tensor: batched images, of shape [batch_size x 3 x H x W]
- samples.mask: a binary mask of shape [batch_size x H x W], containing 1 on padded pixels
It returns a dict with the following elements:
- "pred_logits": the classification logits (including no-object) for all queries.
Shape= [batch_size x num_queries x num_classes]
- "pred_boxes": The normalized boxes coordinates for all queries, represented as
(center_x, center_y, width, height). These values are normalized in [0, 1],
relative to the size of each individual image (disregarding possible padding).
See PostProcess for information on how to retrieve the unnormalized bounding box.
- "aux_outputs": Optional, only returned when auxilary losses are activated. It is a list of
dictionnaries containing the two above keys for each decoder layer.
"""
# if targets is None:
# captions = kw["captions"]
# else:
# captions = [t["caption"] for t in targets]
# # encoder texts
# tokenized = self.tokenizer(captions, padding="longest", return_tensors="pt").to(
# samples.device
# )
# (
# text_self_attention_masks,
# position_ids,
# cate_to_token_mask_list,
# ) = generate_masks_with_special_tokens_and_transfer_map(
# tokenized, self.specical_tokens, self.tokenizer
# )
# if text_self_attention_masks.shape[1] > self.max_text_len:
# text_self_attention_masks = text_self_attention_masks[
# :, : self.max_text_len, : self.max_text_len
# ]
# position_ids = position_ids[:, : self.max_text_len]
# tokenized["input_ids"] = tokenized["input_ids"][:, : self.max_text_len]
# tokenized["attention_mask"] = tokenized["attention_mask"][:, : self.max_text_len]
# tokenized["token_type_ids"] = tokenized["token_type_ids"][:, : self.max_text_len]
# # extract text embeddings
# if self.sub_sentence_present:
# tokenized_for_encoder = {k: v for k, v in tokenized.items() if k != "attention_mask"}
# tokenized_for_encoder["attention_mask"] = text_self_attention_masks
# tokenized_for_encoder["position_ids"] = position_ids
# else:
# # import ipdb; ipdb.set_trace()
# tokenized_for_encoder = tokenized
# ---------------------------输入不同--------------------------
tokenized_for_encoder
=
{}
tokenized_for_encoder
[
"input_ids"
]
=
input_ids
tokenized_for_encoder
[
"attention_mask"
]
=
text_self_attention_masks
tokenized_for_encoder
[
"position_ids"
]
=
position_ids
tokenized_for_encoder
[
"token_type_ids"
]
=
token_type_ids
# --------------------------输入不同结束------------------------
bert_output
=
self
.
bert
(
**
tokenized_for_encoder
)
# bs, 195, 768
encoded_text
=
self
.
feat_map
(
bert_output
[
"last_hidden_state"
])
# bs, 195, d_model
#------------------------------2---------------------------------
# text_token_mask = tokenized.attention_mask.bool() # bs, 195
#-----------------------------------------------------------------
text_token_mask
=
attention_mask
.
bool
()
# bs, 195
# text_token_mask: True for nomask, False for mask
# text_self_attention_masks: True for nomask, False for mask
if
encoded_text
.
shape
[
1
]
>
self
.
max_text_len
:
encoded_text
=
encoded_text
[:,
:
self
.
max_text_len
,
:]
text_token_mask
=
text_token_mask
[:,
:
self
.
max_text_len
]
position_ids
=
position_ids
[:,
:
self
.
max_text_len
]
text_self_attention_masks
=
text_self_attention_masks
[
:,
:
self
.
max_text_len
,
:
self
.
max_text_len
]
text_dict
=
{
"encoded_text"
:
encoded_text
,
# bs, 195, d_model
"text_token_mask"
:
text_token_mask
,
# bs, 195
"position_ids"
:
position_ids
,
# bs, 195
"text_self_attention_masks"
:
text_self_attention_masks
,
# bs, 195,195
}
# import ipdb; ipdb.set_trace()
if
isinstance
(
samples
,
(
list
,
torch
.
Tensor
)):
samples
=
nested_tensor_from_tensor_list
(
samples
)
if
not
hasattr
(
self
,
'features'
)
or
not
hasattr
(
self
,
'poss'
):
self
.
set_image_tensor
(
samples
)
srcs
=
[]
masks
=
[]
for
l
,
feat
in
enumerate
(
self
.
features
):
src
,
mask
=
feat
.
decompose
()
srcs
.
append
(
self
.
input_proj
[
l
](
src
))
masks
.
append
(
mask
)
assert
mask
is
not
None
if
self
.
num_feature_levels
>
len
(
srcs
):
_len_srcs
=
len
(
srcs
)
for
l
in
range
(
_len_srcs
,
self
.
num_feature_levels
):
if
l
==
_len_srcs
:
src
=
self
.
input_proj
[
l
](
self
.
features
[
-
1
].
tensors
)
else
:
src
=
self
.
input_proj
[
l
](
srcs
[
-
1
])
m
=
samples
.
mask
mask
=
F
.
interpolate
(
m
[
None
].
float
(),
size
=
src
.
shape
[
-
2
:]).
to
(
torch
.
bool
)[
0
]
pos_l
=
self
.
backbone
[
1
](
NestedTensor
(
src
,
mask
)).
to
(
src
.
dtype
)
srcs
.
append
(
src
)
masks
.
append
(
mask
)
self
.
poss
.
append
(
pos_l
)
input_query_bbox
=
input_query_label
=
attn_mask
=
dn_meta
=
None
hs
,
reference
,
hs_enc
,
ref_enc
,
init_box_proposal
=
self
.
transformer
(
srcs
,
masks
,
input_query_bbox
,
self
.
poss
,
input_query_label
,
attn_mask
,
text_dict
)
# deformable-detr-like anchor update
outputs_coord_list
=
[]
for
dec_lid
,
(
layer_ref_sig
,
layer_bbox_embed
,
layer_hs
)
in
enumerate
(
zip
(
reference
[:
-
1
],
self
.
bbox_embed
,
hs
)
):
layer_delta_unsig
=
layer_bbox_embed
(
layer_hs
)
layer_outputs_unsig
=
layer_delta_unsig
+
inverse_sigmoid
(
layer_ref_sig
)
layer_outputs_unsig
=
layer_outputs_unsig
.
sigmoid
()
outputs_coord_list
.
append
(
layer_outputs_unsig
)
outputs_coord_list
=
torch
.
stack
(
outputs_coord_list
)
# output
outputs_class
=
torch
.
stack
(
[
layer_cls_embed
(
layer_hs
,
text_dict
)
for
layer_cls_embed
,
layer_hs
in
zip
(
self
.
class_embed
,
hs
)
]
)
out
=
{
"pred_logits"
:
outputs_class
[
-
1
],
"pred_boxes"
:
outputs_coord_list
[
-
1
]}
# # for intermediate outputs
# if self.aux_loss:
# out['aux_outputs'] = self._set_aux_loss(outputs_class, outputs_coord_list)
# # for encoder output
# if hs_enc is not None:
# # prepare intermediate outputs
# interm_coord = ref_enc[-1]
# interm_class = self.transformer.enc_out_class_embed(hs_enc[-1], text_dict)
# out['interm_outputs'] = {'pred_logits': interm_class, 'pred_boxes': interm_coord}
# out['interm_outputs_for_matching_pre'] = {'pred_logits': interm_class, 'pred_boxes': init_box_proposal}
#-------------------------2---------------------------------
# unset_image_tensor = kw.get('unset_image_tensor', True)
# if unset_image_tensor:
# self.unset_image_tensor() ## If necessary
#-----------------------------------------------------------
return
out
@
torch
.
jit
.
unused
def
_set_aux_loss
(
self
,
outputs_class
,
outputs_coord
):
# this is a workaround to make torchscript happy, as torchscript
# doesn't support dictionary with non-homogeneous values, such
# as a dict having both a Tensor and a list.
return
[
{
"pred_logits"
:
a
,
"pred_boxes"
:
b
}
for
a
,
b
in
zip
(
outputs_class
[:
-
1
],
outputs_coord
[:
-
1
])
]
@
MODULE_BUILD_FUNCS
.
registe_with_name
(
module_name
=
"groundingdino"
)
def
build_groundingdino
(
args
):
backbone
=
build_backbone
(
args
)
transformer
=
build_transformer
(
args
)
dn_labelbook_size
=
args
.
dn_labelbook_size
dec_pred_bbox_embed_share
=
args
.
dec_pred_bbox_embed_share
sub_sentence_present
=
args
.
sub_sentence_present
model
=
GroundingDINO
(
backbone
,
transformer
,
num_queries
=
args
.
num_queries
,
aux_loss
=
True
,
iter_update
=
True
,
query_dim
=
4
,
num_feature_levels
=
args
.
num_feature_levels
,
nheads
=
args
.
nheads
,
dec_pred_bbox_embed_share
=
dec_pred_bbox_embed_share
,
two_stage_type
=
args
.
two_stage_type
,
two_stage_bbox_embed_share
=
args
.
two_stage_bbox_embed_share
,
two_stage_class_embed_share
=
args
.
two_stage_class_embed_share
,
num_patterns
=
args
.
num_patterns
,
dn_number
=
0
,
dn_box_noise_scale
=
args
.
dn_box_noise_scale
,
dn_label_noise_ratio
=
args
.
dn_label_noise_ratio
,
dn_labelbook_size
=
dn_labelbook_size
,
text_encoder_type
=
args
.
text_encoder_type
,
sub_sentence_present
=
sub_sentence_present
,
max_text_len
=
args
.
max_text_len
,
)
return
model
groundingdino/models/GroundingDINO/groundingdino_torch.py
0 → 100644
View file @
34e4011b
# ------------------------------------------------------------------------
# Grounding DINO
# url: https://github.com/IDEA-Research/GroundingDINO
# Copyright (c) 2023 IDEA. All Rights Reserved.
# Licensed under the Apache License, Version 2.0 [see LICENSE for details]
# ------------------------------------------------------------------------
# Conditional DETR model and criterion classes.
# Copyright (c) 2021 Microsoft. All Rights Reserved.
# Licensed under the Apache License, Version 2.0 [see LICENSE for details]
# ------------------------------------------------------------------------
# Modified from DETR (https://github.com/facebookresearch/detr)
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
# ------------------------------------------------------------------------
# Modified from Deformable DETR (https://github.com/fundamentalvision/Deformable-DETR)
# Copyright (c) 2020 SenseTime. All Rights Reserved.
# ------------------------------------------------------------------------
import
copy
from
typing
import
List
,
Optional
import
torch
import
torch.nn.functional
as
F
from
torch
import
nn
from
torchvision.ops.boxes
import
nms
from
transformers
import
AutoTokenizer
,
BertModel
,
BertTokenizer
,
RobertaModel
,
RobertaTokenizerFast
from
groundingdino.util
import
box_ops
,
get_tokenlizer
from
groundingdino.util.misc
import
(
NestedTensor
,
accuracy
,
get_world_size
,
interpolate
,
inverse_sigmoid
,
is_dist_avail_and_initialized
,
nested_tensor_from_tensor_list
,
)
from
groundingdino.util.utils
import
get_phrases_from_posmap
from
groundingdino.util.visualizer
import
COCOVisualizer
from
groundingdino.util.vl_utils
import
create_positive_map_from_span
from
..registry
import
MODULE_BUILD_FUNCS
from
.backbone
import
build_backbone
from
.bertwarper
import
(
BertModelWarper
,
generate_masks_with_special_tokens
,
generate_masks_with_special_tokens_and_transfer_map
,
)
from
.transformer
import
build_transformer
from
.utils
import
MLP
,
ContrastiveEmbed
,
sigmoid_focal_loss
class
GroundingDINO
(
nn
.
Module
):
"""This is the Cross-Attention Detector module that performs object detection"""
def
__init__
(
self
,
backbone
,
transformer
,
num_queries
,
aux_loss
=
False
,
iter_update
=
False
,
query_dim
=
2
,
num_feature_levels
=
1
,
nheads
=
8
,
# two stage
two_stage_type
=
"no"
,
# ['no', 'standard']
dec_pred_bbox_embed_share
=
True
,
two_stage_class_embed_share
=
True
,
two_stage_bbox_embed_share
=
True
,
num_patterns
=
0
,
dn_number
=
100
,
dn_box_noise_scale
=
0.4
,
dn_label_noise_ratio
=
0.5
,
dn_labelbook_size
=
100
,
text_encoder_type
=
"bert-base-uncased"
,
sub_sentence_present
=
True
,
max_text_len
=
256
,
):
"""Initializes the model.
Parameters:
backbone: torch module of the backbone to be used. See backbone.py
transformer: torch module of the transformer architecture. See transformer.py
num_queries: number of object queries, ie detection slot. This is the maximal number of objects
Conditional DETR can detect in a single image. For COCO, we recommend 100 queries.
aux_loss: True if auxiliary decoding losses (loss at each decoder layer) are to be used.
"""
super
().
__init__
()
self
.
num_queries
=
num_queries
self
.
transformer
=
transformer
self
.
hidden_dim
=
hidden_dim
=
transformer
.
d_model
self
.
num_feature_levels
=
num_feature_levels
self
.
nheads
=
nheads
self
.
max_text_len
=
256
self
.
sub_sentence_present
=
sub_sentence_present
# setting query dim
self
.
query_dim
=
query_dim
assert
query_dim
==
4
# for dn training
self
.
num_patterns
=
num_patterns
self
.
dn_number
=
dn_number
self
.
dn_box_noise_scale
=
dn_box_noise_scale
self
.
dn_label_noise_ratio
=
dn_label_noise_ratio
self
.
dn_labelbook_size
=
dn_labelbook_size
# bert
self
.
tokenizer
=
get_tokenlizer
.
get_tokenlizer
(
text_encoder_type
)
self
.
bert
=
get_tokenlizer
.
get_pretrained_language_model
(
text_encoder_type
)
self
.
bert
.
pooler
.
dense
.
weight
.
requires_grad_
(
False
)
self
.
bert
.
pooler
.
dense
.
bias
.
requires_grad_
(
False
)
self
.
bert
=
BertModelWarper
(
bert_model
=
self
.
bert
)
self
.
feat_map
=
nn
.
Linear
(
self
.
bert
.
config
.
hidden_size
,
self
.
hidden_dim
,
bias
=
True
)
nn
.
init
.
constant_
(
self
.
feat_map
.
bias
.
data
,
0
)
nn
.
init
.
xavier_uniform_
(
self
.
feat_map
.
weight
.
data
)
# freeze
# special tokens
self
.
specical_tokens
=
self
.
tokenizer
.
convert_tokens_to_ids
([
"[CLS]"
,
"[SEP]"
,
"."
,
"?"
])
# prepare input projection layers
if
num_feature_levels
>
1
:
num_backbone_outs
=
len
(
backbone
.
num_channels
)
input_proj_list
=
[]
for
_
in
range
(
num_backbone_outs
):
in_channels
=
backbone
.
num_channels
[
_
]
input_proj_list
.
append
(
nn
.
Sequential
(
nn
.
Conv2d
(
in_channels
,
hidden_dim
,
kernel_size
=
1
),
nn
.
GroupNorm
(
32
,
hidden_dim
),
)
)
for
_
in
range
(
num_feature_levels
-
num_backbone_outs
):
input_proj_list
.
append
(
nn
.
Sequential
(
nn
.
Conv2d
(
in_channels
,
hidden_dim
,
kernel_size
=
3
,
stride
=
2
,
padding
=
1
),
nn
.
GroupNorm
(
32
,
hidden_dim
),
)
)
in_channels
=
hidden_dim
self
.
input_proj
=
nn
.
ModuleList
(
input_proj_list
)
else
:
assert
two_stage_type
==
"no"
,
"two_stage_type should be no if num_feature_levels=1 !!!"
self
.
input_proj
=
nn
.
ModuleList
(
[
nn
.
Sequential
(
nn
.
Conv2d
(
backbone
.
num_channels
[
-
1
],
hidden_dim
,
kernel_size
=
1
),
nn
.
GroupNorm
(
32
,
hidden_dim
),
)
]
)
self
.
backbone
=
backbone
self
.
aux_loss
=
aux_loss
self
.
box_pred_damping
=
box_pred_damping
=
None
self
.
iter_update
=
iter_update
assert
iter_update
,
"Why not iter_update?"
# prepare pred layers
self
.
dec_pred_bbox_embed_share
=
dec_pred_bbox_embed_share
# prepare class & box embed
_class_embed
=
ContrastiveEmbed
()
_bbox_embed
=
MLP
(
hidden_dim
,
hidden_dim
,
4
,
3
)
nn
.
init
.
constant_
(
_bbox_embed
.
layers
[
-
1
].
weight
.
data
,
0
)
nn
.
init
.
constant_
(
_bbox_embed
.
layers
[
-
1
].
bias
.
data
,
0
)
if
dec_pred_bbox_embed_share
:
box_embed_layerlist
=
[
_bbox_embed
for
i
in
range
(
transformer
.
num_decoder_layers
)]
else
:
box_embed_layerlist
=
[
copy
.
deepcopy
(
_bbox_embed
)
for
i
in
range
(
transformer
.
num_decoder_layers
)
]
class_embed_layerlist
=
[
_class_embed
for
i
in
range
(
transformer
.
num_decoder_layers
)]
self
.
bbox_embed
=
nn
.
ModuleList
(
box_embed_layerlist
)
self
.
class_embed
=
nn
.
ModuleList
(
class_embed_layerlist
)
self
.
transformer
.
decoder
.
bbox_embed
=
self
.
bbox_embed
self
.
transformer
.
decoder
.
class_embed
=
self
.
class_embed
# two stage
self
.
two_stage_type
=
two_stage_type
assert
two_stage_type
in
[
"no"
,
"standard"
],
"unknown param {} of two_stage_type"
.
format
(
two_stage_type
)
if
two_stage_type
!=
"no"
:
if
two_stage_bbox_embed_share
:
assert
dec_pred_bbox_embed_share
self
.
transformer
.
enc_out_bbox_embed
=
_bbox_embed
else
:
self
.
transformer
.
enc_out_bbox_embed
=
copy
.
deepcopy
(
_bbox_embed
)
if
two_stage_class_embed_share
:
assert
dec_pred_bbox_embed_share
self
.
transformer
.
enc_out_class_embed
=
_class_embed
else
:
self
.
transformer
.
enc_out_class_embed
=
copy
.
deepcopy
(
_class_embed
)
self
.
refpoint_embed
=
None
self
.
_reset_parameters
()
def
_reset_parameters
(
self
):
# init input_proj
for
proj
in
self
.
input_proj
:
nn
.
init
.
xavier_uniform_
(
proj
[
0
].
weight
,
gain
=
1
)
nn
.
init
.
constant_
(
proj
[
0
].
bias
,
0
)
def
set_image_tensor
(
self
,
samples
:
NestedTensor
):
if
isinstance
(
samples
,
(
list
,
torch
.
Tensor
)):
samples
=
nested_tensor_from_tensor_list
(
samples
)
self
.
features
,
self
.
poss
=
self
.
backbone
(
samples
)
def
unset_image_tensor
(
self
):
if
hasattr
(
self
,
'features'
):
del
self
.
features
if
hasattr
(
self
,
'poss'
):
del
self
.
poss
def
set_image_features
(
self
,
features
,
poss
):
self
.
features
=
features
self
.
poss
=
poss
def
init_ref_points
(
self
,
use_num_queries
):
self
.
refpoint_embed
=
nn
.
Embedding
(
use_num_queries
,
self
.
query_dim
)
def
forward
(
self
,
samples
:
NestedTensor
,
targets
:
List
=
None
,
**
kw
):
# def forward(self, samples: NestedTensor,
# input_ids: Optional[torch.Tensor] = None,
# attention_mask: Optional[torch.Tensor] = None,
# position_ids: Optional[torch.Tensor] = None,
# token_type_ids: Optional[torch.Tensor] = None,
# text_self_attention_masks: Optional[torch.Tensor] = None):
"""The forward expects a NestedTensor, which consists of:
- samples.tensor: batched images, of shape [batch_size x 3 x H x W]
- samples.mask: a binary mask of shape [batch_size x H x W], containing 1 on padded pixels
It returns a dict with the following elements:
- "pred_logits": the classification logits (including no-object) for all queries.
Shape= [batch_size x num_queries x num_classes]
- "pred_boxes": The normalized boxes coordinates for all queries, represented as
(center_x, center_y, width, height). These values are normalized in [0, 1],
relative to the size of each individual image (disregarding possible padding).
See PostProcess for information on how to retrieve the unnormalized bounding box.
- "aux_outputs": Optional, only returned when auxilary losses are activated. It is a list of
dictionnaries containing the two above keys for each decoder layer.
"""
if
targets
is
None
:
captions
=
kw
[
"captions"
]
else
:
captions
=
[
t
[
"caption"
]
for
t
in
targets
]
# encoder texts
tokenized
=
self
.
tokenizer
(
captions
,
padding
=
"longest"
,
return_tensors
=
"pt"
).
to
(
samples
.
device
)
(
text_self_attention_masks
,
position_ids
,
cate_to_token_mask_list
,
)
=
generate_masks_with_special_tokens_and_transfer_map
(
tokenized
,
self
.
specical_tokens
,
self
.
tokenizer
)
if
text_self_attention_masks
.
shape
[
1
]
>
self
.
max_text_len
:
text_self_attention_masks
=
text_self_attention_masks
[
:,
:
self
.
max_text_len
,
:
self
.
max_text_len
]
position_ids
=
position_ids
[:,
:
self
.
max_text_len
]
tokenized
[
"input_ids"
]
=
tokenized
[
"input_ids"
][:,
:
self
.
max_text_len
]
tokenized
[
"attention_mask"
]
=
tokenized
[
"attention_mask"
][:,
:
self
.
max_text_len
]
tokenized
[
"token_type_ids"
]
=
tokenized
[
"token_type_ids"
][:,
:
self
.
max_text_len
]
# extract text embeddings
if
self
.
sub_sentence_present
:
tokenized_for_encoder
=
{
k
:
v
for
k
,
v
in
tokenized
.
items
()
if
k
!=
"attention_mask"
}
tokenized_for_encoder
[
"attention_mask"
]
=
text_self_attention_masks
tokenized_for_encoder
[
"position_ids"
]
=
position_ids
else
:
# import ipdb; ipdb.set_trace()
tokenized_for_encoder
=
tokenized
# ---------------------------输入不同--------------------------
# tokenized_for_encoder ={}
# tokenized_for_encoder["input_ids"] = input_ids
# tokenized_for_encoder["attention_mask"] = text_self_attention_masks
# tokenized_for_encoder["position_ids"] = position_ids
# tokenized_for_encoder["token_type_ids"] = token_type_ids
# --------------------------输入不同结束------------------------
bert_output
=
self
.
bert
(
**
tokenized_for_encoder
)
# bs, 195, 768
encoded_text
=
self
.
feat_map
(
bert_output
[
"last_hidden_state"
])
# bs, 195, d_model
#------------------------------2---------------------------------
text_token_mask
=
tokenized
.
attention_mask
.
bool
()
# bs, 195
#-----------------------------------------------------------------
# text_token_mask = attention_mask.bool() # bs, 195
# text_token_mask: True for nomask, False for mask
# text_self_attention_masks: True for nomask, False for mask
if
encoded_text
.
shape
[
1
]
>
self
.
max_text_len
:
encoded_text
=
encoded_text
[:,
:
self
.
max_text_len
,
:]
text_token_mask
=
text_token_mask
[:,
:
self
.
max_text_len
]
position_ids
=
position_ids
[:,
:
self
.
max_text_len
]
text_self_attention_masks
=
text_self_attention_masks
[
:,
:
self
.
max_text_len
,
:
self
.
max_text_len
]
text_dict
=
{
"encoded_text"
:
encoded_text
,
# bs, 195, d_model
"text_token_mask"
:
text_token_mask
,
# bs, 195
"position_ids"
:
position_ids
,
# bs, 195
"text_self_attention_masks"
:
text_self_attention_masks
,
# bs, 195,195
}
# import ipdb; ipdb.set_trace()
if
isinstance
(
samples
,
(
list
,
torch
.
Tensor
)):
samples
=
nested_tensor_from_tensor_list
(
samples
)
if
not
hasattr
(
self
,
'features'
)
or
not
hasattr
(
self
,
'poss'
):
self
.
set_image_tensor
(
samples
)
srcs
=
[]
masks
=
[]
for
l
,
feat
in
enumerate
(
self
.
features
):
src
,
mask
=
feat
.
decompose
()
srcs
.
append
(
self
.
input_proj
[
l
](
src
))
masks
.
append
(
mask
)
assert
mask
is
not
None
if
self
.
num_feature_levels
>
len
(
srcs
):
_len_srcs
=
len
(
srcs
)
for
l
in
range
(
_len_srcs
,
self
.
num_feature_levels
):
if
l
==
_len_srcs
:
src
=
self
.
input_proj
[
l
](
self
.
features
[
-
1
].
tensors
)
else
:
src
=
self
.
input_proj
[
l
](
srcs
[
-
1
])
m
=
samples
.
mask
mask
=
F
.
interpolate
(
m
[
None
].
float
(),
size
=
src
.
shape
[
-
2
:]).
to
(
torch
.
bool
)[
0
]
pos_l
=
self
.
backbone
[
1
](
NestedTensor
(
src
,
mask
)).
to
(
src
.
dtype
)
srcs
.
append
(
src
)
masks
.
append
(
mask
)
self
.
poss
.
append
(
pos_l
)
input_query_bbox
=
input_query_label
=
attn_mask
=
dn_meta
=
None
hs
,
reference
,
hs_enc
,
ref_enc
,
init_box_proposal
=
self
.
transformer
(
srcs
,
masks
,
input_query_bbox
,
self
.
poss
,
input_query_label
,
attn_mask
,
text_dict
)
# deformable-detr-like anchor update
outputs_coord_list
=
[]
for
dec_lid
,
(
layer_ref_sig
,
layer_bbox_embed
,
layer_hs
)
in
enumerate
(
zip
(
reference
[:
-
1
],
self
.
bbox_embed
,
hs
)
):
layer_delta_unsig
=
layer_bbox_embed
(
layer_hs
)
layer_outputs_unsig
=
layer_delta_unsig
+
inverse_sigmoid
(
layer_ref_sig
)
layer_outputs_unsig
=
layer_outputs_unsig
.
sigmoid
()
outputs_coord_list
.
append
(
layer_outputs_unsig
)
outputs_coord_list
=
torch
.
stack
(
outputs_coord_list
)
# output
outputs_class
=
torch
.
stack
(
[
layer_cls_embed
(
layer_hs
,
text_dict
)
for
layer_cls_embed
,
layer_hs
in
zip
(
self
.
class_embed
,
hs
)
]
)
out
=
{
"pred_logits"
:
outputs_class
[
-
1
],
"pred_boxes"
:
outputs_coord_list
[
-
1
]}
# # for intermediate outputs
# if self.aux_loss:
# out['aux_outputs'] = self._set_aux_loss(outputs_class, outputs_coord_list)
# # for encoder output
# if hs_enc is not None:
# # prepare intermediate outputs
# interm_coord = ref_enc[-1]
# interm_class = self.transformer.enc_out_class_embed(hs_enc[-1], text_dict)
# out['interm_outputs'] = {'pred_logits': interm_class, 'pred_boxes': interm_coord}
# out['interm_outputs_for_matching_pre'] = {'pred_logits': interm_class, 'pred_boxes': init_box_proposal}
#-------------------------2---------------------------------
unset_image_tensor
=
kw
.
get
(
'unset_image_tensor'
,
True
)
if
unset_image_tensor
:
self
.
unset_image_tensor
()
## If necessary
#-----------------------------------------------------------
return
out
@
torch
.
jit
.
unused
def
_set_aux_loss
(
self
,
outputs_class
,
outputs_coord
):
# this is a workaround to make torchscript happy, as torchscript
# doesn't support dictionary with non-homogeneous values, such
# as a dict having both a Tensor and a list.
return
[
{
"pred_logits"
:
a
,
"pred_boxes"
:
b
}
for
a
,
b
in
zip
(
outputs_class
[:
-
1
],
outputs_coord
[:
-
1
])
]
@
MODULE_BUILD_FUNCS
.
registe_with_name
(
module_name
=
"groundingdino"
)
def
build_groundingdino
(
args
):
backbone
=
build_backbone
(
args
)
transformer
=
build_transformer
(
args
)
dn_labelbook_size
=
args
.
dn_labelbook_size
dec_pred_bbox_embed_share
=
args
.
dec_pred_bbox_embed_share
sub_sentence_present
=
args
.
sub_sentence_present
model
=
GroundingDINO
(
backbone
,
transformer
,
num_queries
=
args
.
num_queries
,
aux_loss
=
True
,
iter_update
=
True
,
query_dim
=
4
,
num_feature_levels
=
args
.
num_feature_levels
,
nheads
=
args
.
nheads
,
dec_pred_bbox_embed_share
=
dec_pred_bbox_embed_share
,
two_stage_type
=
args
.
two_stage_type
,
two_stage_bbox_embed_share
=
args
.
two_stage_bbox_embed_share
,
two_stage_class_embed_share
=
args
.
two_stage_class_embed_share
,
num_patterns
=
args
.
num_patterns
,
dn_number
=
0
,
dn_box_noise_scale
=
args
.
dn_box_noise_scale
,
dn_label_noise_ratio
=
args
.
dn_label_noise_ratio
,
dn_labelbook_size
=
dn_labelbook_size
,
text_encoder_type
=
args
.
text_encoder_type
,
sub_sentence_present
=
sub_sentence_present
,
max_text_len
=
args
.
max_text_len
,
)
return
model
groundingdino/models/GroundingDINO/ms_deform_attn copy.py
0 → 100644
View file @
34e4011b
# ------------------------------------------------------------------------
# Grounding DINO
# url: https://github.com/IDEA-Research/GroundingDINO
# Copyright (c) 2023 IDEA. All Rights Reserved.
# Licensed under the Apache License, Version 2.0 [see LICENSE for details]
# ------------------------------------------------------------------------
# Deformable DETR
# Copyright (c) 2020 SenseTime. All Rights Reserved.
# Licensed under the Apache License, Version 2.0 [see LICENSE for details]
# ------------------------------------------------------------------------------------------------
# Modified from:
# https://github.com/fundamentalvision/Deformable-DETR/blob/main/models/ops/functions/ms_deform_attn_func.py
# https://github.com/fundamentalvision/Deformable-DETR/blob/main/models/ops/modules/ms_deform_attn.py
# https://github.com/open-mmlab/mmcv/blob/master/mmcv/ops/multi_scale_deform_attn.py
# ------------------------------------------------------------------------------------------------
import
math
import
warnings
from
typing
import
Optional
import
torch
import
torch.nn
as
nn
import
torch.nn.functional
as
F
from
torch.autograd
import
Function
from
torch.autograd.function
import
once_differentiable
from
torch.nn.init
import
constant_
,
xavier_uniform_
try
:
from
groundingdino
import
_C
except
:
warnings
.
warn
(
"Failed to load custom C++ ops. Running on CPU mode Only!"
)
# helpers
def
_is_power_of_2
(
n
):
if
(
not
isinstance
(
n
,
int
))
or
(
n
<
0
):
raise
ValueError
(
"invalid input for _is_power_of_2: {} (type: {})"
.
format
(
n
,
type
(
n
)))
return
(
n
&
(
n
-
1
)
==
0
)
and
n
!=
0
class
MultiScaleDeformableAttnFunction
(
Function
):
@
staticmethod
def
forward
(
ctx
,
value
,
value_spatial_shapes
,
value_level_start_index
,
sampling_locations
,
attention_weights
,
im2col_step
,
):
ctx
.
im2col_step
=
im2col_step
output
=
_C
.
ms_deform_attn_forward
(
value
,
value_spatial_shapes
,
value_level_start_index
,
sampling_locations
,
attention_weights
,
ctx
.
im2col_step
,
)
ctx
.
save_for_backward
(
value
,
value_spatial_shapes
,
value_level_start_index
,
sampling_locations
,
attention_weights
,
)
return
output
@
staticmethod
@
once_differentiable
def
backward
(
ctx
,
grad_output
):
(
value
,
value_spatial_shapes
,
value_level_start_index
,
sampling_locations
,
attention_weights
,
)
=
ctx
.
saved_tensors
grad_value
,
grad_sampling_loc
,
grad_attn_weight
=
_C
.
ms_deform_attn_backward
(
value
,
value_spatial_shapes
,
value_level_start_index
,
sampling_locations
,
attention_weights
,
grad_output
,
ctx
.
im2col_step
,
)
return
grad_value
,
None
,
None
,
grad_sampling_loc
,
grad_attn_weight
,
None
def
multi_scale_deformable_attn_pytorch
(
value
:
torch
.
Tensor
,
value_spatial_shapes
:
torch
.
Tensor
,
sampling_locations
:
torch
.
Tensor
,
attention_weights
:
torch
.
Tensor
,
)
->
torch
.
Tensor
:
bs
,
_
,
num_heads
,
embed_dims
=
value
.
shape
_
,
num_queries
,
num_heads
,
num_levels
,
num_points
,
_
=
sampling_locations
.
shape
value_list
=
value
.
split
([
H_
*
W_
for
H_
,
W_
in
value_spatial_shapes
],
dim
=
1
)
sampling_grids
=
2
*
sampling_locations
-
1
sampling_value_list
=
[]
for
level
,
(
H_
,
W_
)
in
enumerate
(
value_spatial_shapes
):
# bs, H_*W_, num_heads, embed_dims ->
# bs, H_*W_, num_heads*embed_dims ->
# bs, num_heads*embed_dims, H_*W_ ->
# bs*num_heads, embed_dims, H_, W_
value_l_
=
(
value_list
[
level
].
flatten
(
2
).
transpose
(
1
,
2
).
reshape
(
bs
*
num_heads
,
embed_dims
,
H_
,
W_
)
)
# bs, num_queries, num_heads, num_points, 2 ->
# bs, num_heads, num_queries, num_points, 2 ->
# bs*num_heads, num_queries, num_points, 2
sampling_grid_l_
=
sampling_grids
[:,
:,
:,
level
].
transpose
(
1
,
2
).
flatten
(
0
,
1
)
# bs*num_heads, embed_dims, num_queries, num_points
sampling_value_l_
=
F
.
grid_sample
(
value_l_
,
sampling_grid_l_
,
mode
=
"bilinear"
,
padding_mode
=
"zeros"
,
align_corners
=
False
)
sampling_value_list
.
append
(
sampling_value_l_
)
# (bs, num_queries, num_heads, num_levels, num_points) ->
# (bs, num_heads, num_queries, num_levels, num_points) ->
# (bs, num_heads, 1, num_queries, num_levels*num_points)
attention_weights
=
attention_weights
.
transpose
(
1
,
2
).
reshape
(
bs
*
num_heads
,
1
,
num_queries
,
num_levels
*
num_points
)
output
=
(
(
torch
.
stack
(
sampling_value_list
,
dim
=-
2
).
flatten
(
-
2
)
*
attention_weights
)
.
sum
(
-
1
)
.
view
(
bs
,
num_heads
*
embed_dims
,
num_queries
)
)
return
output
.
transpose
(
1
,
2
).
contiguous
()
class
MultiScaleDeformableAttention
(
nn
.
Module
):
"""Multi-Scale Deformable Attention Module used in Deformable-DETR
`Deformable DETR: Deformable Transformers for End-to-End Object Detection.
<https://arxiv.org/pdf/2010.04159.pdf>`_.
Args:
embed_dim (int): The embedding dimension of Attention. Default: 256.
num_heads (int): The number of attention heads. Default: 8.
num_levels (int): The number of feature map used in Attention. Default: 4.
num_points (int): The number of sampling points for each query
in each head. Default: 4.
img2col_steps (int): The step used in image_to_column. Defualt: 64.
dropout (float): Dropout layer used in output. Default: 0.1.
batch_first (bool): if ``True``, then the input and output tensor will be
provided as `(bs, n, embed_dim)`. Default: False. `(n, bs, embed_dim)`
"""
def
__init__
(
self
,
embed_dim
:
int
=
256
,
num_heads
:
int
=
8
,
num_levels
:
int
=
4
,
num_points
:
int
=
4
,
img2col_step
:
int
=
64
,
batch_first
:
bool
=
False
,
):
super
().
__init__
()
if
embed_dim
%
num_heads
!=
0
:
raise
ValueError
(
"embed_dim must be divisible by num_heads, but got {} and {}"
.
format
(
embed_dim
,
num_heads
)
)
head_dim
=
embed_dim
//
num_heads
self
.
batch_first
=
batch_first
if
not
_is_power_of_2
(
head_dim
):
warnings
.
warn
(
"""
You'd better set d_model in MSDeformAttn to make sure that
each dim of the attention head a power of 2, which is more efficient.
"""
)
self
.
im2col_step
=
img2col_step
self
.
embed_dim
=
embed_dim
self
.
num_heads
=
num_heads
self
.
num_levels
=
num_levels
self
.
num_points
=
num_points
self
.
sampling_offsets
=
nn
.
Linear
(
embed_dim
,
num_heads
*
num_levels
*
num_points
*
2
)
self
.
attention_weights
=
nn
.
Linear
(
embed_dim
,
num_heads
*
num_levels
*
num_points
)
self
.
value_proj
=
nn
.
Linear
(
embed_dim
,
embed_dim
)
self
.
output_proj
=
nn
.
Linear
(
embed_dim
,
embed_dim
)
self
.
init_weights
()
def
_reset_parameters
(
self
):
return
self
.
init_weights
()
def
init_weights
(
self
):
"""
Default initialization for Parameters of Module.
"""
constant_
(
self
.
sampling_offsets
.
weight
.
data
,
0.0
)
thetas
=
torch
.
arange
(
self
.
num_heads
,
dtype
=
torch
.
float32
)
*
(
2.0
*
math
.
pi
/
self
.
num_heads
)
grid_init
=
torch
.
stack
([
thetas
.
cos
(),
thetas
.
sin
()],
-
1
)
grid_init
=
(
(
grid_init
/
grid_init
.
abs
().
max
(
-
1
,
keepdim
=
True
)[
0
])
.
view
(
self
.
num_heads
,
1
,
1
,
2
)
.
repeat
(
1
,
self
.
num_levels
,
self
.
num_points
,
1
)
)
for
i
in
range
(
self
.
num_points
):
grid_init
[:,
:,
i
,
:]
*=
i
+
1
with
torch
.
no_grad
():
self
.
sampling_offsets
.
bias
=
nn
.
Parameter
(
grid_init
.
view
(
-
1
))
constant_
(
self
.
attention_weights
.
weight
.
data
,
0.0
)
constant_
(
self
.
attention_weights
.
bias
.
data
,
0.0
)
xavier_uniform_
(
self
.
value_proj
.
weight
.
data
)
constant_
(
self
.
value_proj
.
bias
.
data
,
0.0
)
xavier_uniform_
(
self
.
output_proj
.
weight
.
data
)
constant_
(
self
.
output_proj
.
bias
.
data
,
0.0
)
def
freeze_sampling_offsets
(
self
):
print
(
"Freeze sampling offsets"
)
self
.
sampling_offsets
.
weight
.
requires_grad
=
False
self
.
sampling_offsets
.
bias
.
requires_grad
=
False
def
freeze_attention_weights
(
self
):
print
(
"Freeze attention weights"
)
self
.
attention_weights
.
weight
.
requires_grad
=
False
self
.
attention_weights
.
bias
.
requires_grad
=
False
def
forward
(
self
,
query
:
torch
.
Tensor
,
key
:
Optional
[
torch
.
Tensor
]
=
None
,
value
:
Optional
[
torch
.
Tensor
]
=
None
,
query_pos
:
Optional
[
torch
.
Tensor
]
=
None
,
key_padding_mask
:
Optional
[
torch
.
Tensor
]
=
None
,
reference_points
:
Optional
[
torch
.
Tensor
]
=
None
,
spatial_shapes
:
Optional
[
torch
.
Tensor
]
=
None
,
level_start_index
:
Optional
[
torch
.
Tensor
]
=
None
,
**
kwargs
)
->
torch
.
Tensor
:
"""Forward Function of MultiScaleDeformableAttention
Args:
query (torch.Tensor): Query embeddings with shape
`(num_query, bs, embed_dim)`
key (torch.Tensor): Key embeddings with shape
`(num_key, bs, embed_dim)`
value (torch.Tensor): Value embeddings with shape
`(num_key, bs, embed_dim)`
query_pos (torch.Tensor): The position embedding for `query`. Default: None.
key_padding_mask (torch.Tensor): ByteTensor for `query`, with shape `(bs, num_key)`,
indicating which elements within `key` to be ignored in attention.
reference_points (torch.Tensor): The normalized reference points
with shape `(bs, num_query, num_levels, 2)`,
all elements is range in [0, 1], top-left (0, 0),
bottom-right (1, 1), including padding are.
or `(N, Length_{query}, num_levels, 4)`, add additional
two dimensions `(h, w)` to form reference boxes.
spatial_shapes (torch.Tensor): Spatial shape of features in different levels.
With shape `(num_levels, 2)`, last dimension represents `(h, w)`.
level_start_index (torch.Tensor): The start index of each level. A tensor with
shape `(num_levels, )` which can be represented as
`[0, h_0 * w_0, h_0 * w_0 + h_1 * w_1, ...]`.
Returns:
torch.Tensor: forward results with shape `(num_query, bs, embed_dim)`
"""
if
value
is
None
:
value
=
query
if
query_pos
is
not
None
:
query
=
query
+
query_pos
if
not
self
.
batch_first
:
# change to (bs, num_query ,embed_dims)
query
=
query
.
permute
(
1
,
0
,
2
)
value
=
value
.
permute
(
1
,
0
,
2
)
bs
,
num_query
,
_
=
query
.
shape
bs
,
num_value
,
_
=
value
.
shape
assert
(
spatial_shapes
[:,
0
]
*
spatial_shapes
[:,
1
]).
sum
()
==
num_value
value
=
self
.
value_proj
(
value
)
if
key_padding_mask
is
not
None
:
value
=
value
.
masked_fill
(
key_padding_mask
[...,
None
],
float
(
0
))
value
=
value
.
view
(
bs
,
num_value
,
self
.
num_heads
,
-
1
)
sampling_offsets
=
self
.
sampling_offsets
(
query
).
view
(
bs
,
num_query
,
self
.
num_heads
,
self
.
num_levels
,
self
.
num_points
,
2
)
attention_weights
=
self
.
attention_weights
(
query
).
view
(
bs
,
num_query
,
self
.
num_heads
,
self
.
num_levels
*
self
.
num_points
)
attention_weights
=
attention_weights
.
softmax
(
-
1
)
attention_weights
=
attention_weights
.
view
(
bs
,
num_query
,
self
.
num_heads
,
self
.
num_levels
,
self
.
num_points
,
)
# bs, num_query, num_heads, num_levels, num_points, 2
if
reference_points
.
shape
[
-
1
]
==
2
:
offset_normalizer
=
torch
.
stack
([
spatial_shapes
[...,
1
],
spatial_shapes
[...,
0
]],
-
1
)
sampling_locations
=
(
reference_points
[:,
:,
None
,
:,
None
,
:]
+
sampling_offsets
/
offset_normalizer
[
None
,
None
,
None
,
:,
None
,
:]
)
elif
reference_points
.
shape
[
-
1
]
==
4
:
sampling_locations
=
(
reference_points
[:,
:,
None
,
:,
None
,
:
2
]
+
sampling_offsets
/
self
.
num_points
*
reference_points
[:,
:,
None
,
:,
None
,
2
:]
*
0.5
)
else
:
raise
ValueError
(
"Last dim of reference_points must be 2 or 4, but get {} instead."
.
format
(
reference_points
.
shape
[
-
1
]
)
)
if
torch
.
cuda
.
is_available
()
and
value
.
is_cuda
:
halffloat
=
False
if
value
.
dtype
==
torch
.
float16
:
halffloat
=
True
value
=
value
.
float
()
sampling_locations
=
sampling_locations
.
float
()
attention_weights
=
attention_weights
.
float
()
output
=
MultiScaleDeformableAttnFunction
.
apply
(
value
,
spatial_shapes
,
level_start_index
,
sampling_locations
,
attention_weights
,
self
.
im2col_step
,
)
if
halffloat
:
output
=
output
.
half
()
else
:
output
=
multi_scale_deformable_attn_pytorch
(
value
,
spatial_shapes
,
sampling_locations
,
attention_weights
)
output
=
self
.
output_proj
(
output
)
if
not
self
.
batch_first
:
output
=
output
.
permute
(
1
,
0
,
2
)
return
output
def
create_dummy_class
(
klass
,
dependency
,
message
=
""
):
"""
When a dependency of a class is not available, create a dummy class which throws ImportError
when used.
Args:
klass (str): name of the class.
dependency (str): name of the dependency.
message: extra message to print
Returns:
class: a class object
"""
err
=
"Cannot import '{}', therefore '{}' is not available."
.
format
(
dependency
,
klass
)
if
message
:
err
=
err
+
" "
+
message
class
_DummyMetaClass
(
type
):
# throw error on class attribute access
def
__getattr__
(
_
,
__
):
# noqa: B902
raise
ImportError
(
err
)
class
_Dummy
(
object
,
metaclass
=
_DummyMetaClass
):
# throw error on constructor
def
__init__
(
self
,
*
args
,
**
kwargs
):
raise
ImportError
(
err
)
return
_Dummy
def
create_dummy_func
(
func
,
dependency
,
message
=
""
):
"""
When a dependency of a function is not available, create a dummy function which throws
ImportError when used.
Args:
func (str): name of the function.
dependency (str or list[str]): name(s) of the dependency.
message: extra message to print
Returns:
function: a function object
"""
err
=
"Cannot import '{}', therefore '{}' is not available."
.
format
(
dependency
,
func
)
if
message
:
err
=
err
+
" "
+
message
if
isinstance
(
dependency
,
(
list
,
tuple
)):
dependency
=
","
.
join
(
dependency
)
def
_dummy
(
*
args
,
**
kwargs
):
raise
ImportError
(
err
)
return
_dummy
groundingdino/models/GroundingDINO/ms_deform_attn.py
0 → 100644
View file @
34e4011b
# ------------------------------------------------------------------------
# Grounding DINO
# url: https://github.com/IDEA-Research/GroundingDINO
# Copyright (c) 2023 IDEA. All Rights Reserved.
# Licensed under the Apache License, Version 2.0 [see LICENSE for details]
# ------------------------------------------------------------------------
# Deformable DETR
# Copyright (c) 2020 SenseTime. All Rights Reserved.
# Licensed under the Apache License, Version 2.0 [see LICENSE for details]
# ------------------------------------------------------------------------------------------------
# Modified from:
# https://github.com/fundamentalvision/Deformable-DETR/blob/main/models/ops/functions/ms_deform_attn_func.py
# https://github.com/fundamentalvision/Deformable-DETR/blob/main/models/ops/modules/ms_deform_attn.py
# https://github.com/open-mmlab/mmcv/blob/master/mmcv/ops/multi_scale_deform_attn.py
# ------------------------------------------------------------------------------------------------
import
math
import
warnings
from
typing
import
Optional
import
torch
import
torch.nn
as
nn
import
torch.nn.functional
as
F
from
torch.autograd
import
Function
from
torch.autograd.function
import
once_differentiable
from
torch.nn.init
import
constant_
,
xavier_uniform_
try
:
from
groundingdino
import
_C
except
:
warnings
.
warn
(
"Failed to load custom C++ ops. Running on CPU mode Only!"
)
# helpers
def
_is_power_of_2
(
n
):
if
(
not
isinstance
(
n
,
int
))
or
(
n
<
0
):
raise
ValueError
(
"invalid input for _is_power_of_2: {} (type: {})"
.
format
(
n
,
type
(
n
)))
return
(
n
&
(
n
-
1
)
==
0
)
and
n
!=
0
class
MultiScaleDeformableAttnFunction
(
Function
):
@
staticmethod
def
forward
(
ctx
,
value
,
value_spatial_shapes
,
value_level_start_index
,
sampling_locations
,
attention_weights
,
im2col_step
,
):
ctx
.
im2col_step
=
im2col_step
output
=
_C
.
ms_deform_attn_forward
(
value
,
value_spatial_shapes
,
value_level_start_index
,
sampling_locations
,
attention_weights
,
ctx
.
im2col_step
,
)
ctx
.
save_for_backward
(
value
,
value_spatial_shapes
,
value_level_start_index
,
sampling_locations
,
attention_weights
,
)
return
output
@
staticmethod
@
once_differentiable
def
backward
(
ctx
,
grad_output
):
(
value
,
value_spatial_shapes
,
value_level_start_index
,
sampling_locations
,
attention_weights
,
)
=
ctx
.
saved_tensors
grad_value
,
grad_sampling_loc
,
grad_attn_weight
=
_C
.
ms_deform_attn_backward
(
value
,
value_spatial_shapes
,
value_level_start_index
,
sampling_locations
,
attention_weights
,
grad_output
,
ctx
.
im2col_step
,
)
return
grad_value
,
None
,
None
,
grad_sampling_loc
,
grad_attn_weight
,
None
# =================================================================
# 新增这个 symbolic 方法,用于 ONNX 导出
# =================================================================
@
staticmethod
def
symbolic
(
g
,
value
,
value_spatial_shapes
,
value_level_start_index
,
sampling_locations
,
attention_weights
,
im2col_step
,
):
"""
当调用 torch.onnx.export 时,遇到此 Function 会执行这个逻辑。
"""
# g.op 表示在 ONNX 图中插入一个节点
# "custom::ms_deform_attn" 是我们给它起的唯一名字 (Domain::OpName)
# im2col_step 是一个标量整数,在 ONNX 中作为 attribute 传入,后缀 _i 表示 int
return
g
.
op
(
"custom::ms_deform_attn"
,
value
,
value_spatial_shapes
,
value_level_start_index
,
sampling_locations
,
attention_weights
,
im2col_step_i
=
im2col_step
,
)
def
multi_scale_deformable_attn_pytorch
(
value
:
torch
.
Tensor
,
value_spatial_shapes
:
torch
.
Tensor
,
sampling_locations
:
torch
.
Tensor
,
attention_weights
:
torch
.
Tensor
,
)
->
torch
.
Tensor
:
bs
,
_
,
num_heads
,
embed_dims
=
value
.
shape
_
,
num_queries
,
num_heads
,
num_levels
,
num_points
,
_
=
sampling_locations
.
shape
value_list
=
value
.
split
([
H_
*
W_
for
H_
,
W_
in
value_spatial_shapes
],
dim
=
1
)
sampling_grids
=
2
*
sampling_locations
-
1
sampling_value_list
=
[]
for
level
,
(
H_
,
W_
)
in
enumerate
(
value_spatial_shapes
):
# bs, H_*W_, num_heads, embed_dims ->
# bs, H_*W_, num_heads*embed_dims ->
# bs, num_heads*embed_dims, H_*W_ ->
# bs*num_heads, embed_dims, H_, W_
value_l_
=
(
value_list
[
level
].
flatten
(
2
).
transpose
(
1
,
2
).
reshape
(
bs
*
num_heads
,
embed_dims
,
H_
,
W_
)
)
# bs, num_queries, num_heads, num_points, 2 ->
# bs, num_heads, num_queries, num_points, 2 ->
# bs*num_heads, num_queries, num_points, 2
sampling_grid_l_
=
sampling_grids
[:,
:,
:,
level
].
transpose
(
1
,
2
).
flatten
(
0
,
1
)
# bs*num_heads, embed_dims, num_queries, num_points
sampling_value_l_
=
F
.
grid_sample
(
value_l_
,
sampling_grid_l_
,
mode
=
"bilinear"
,
padding_mode
=
"zeros"
,
align_corners
=
False
)
sampling_value_list
.
append
(
sampling_value_l_
)
# (bs, num_queries, num_heads, num_levels, num_points) ->
# (bs, num_heads, num_queries, num_levels, num_points) ->
# (bs, num_heads, 1, num_queries, num_levels*num_points)
attention_weights
=
attention_weights
.
transpose
(
1
,
2
).
reshape
(
bs
*
num_heads
,
1
,
num_queries
,
num_levels
*
num_points
)
output
=
(
(
torch
.
stack
(
sampling_value_list
,
dim
=-
2
).
flatten
(
-
2
)
*
attention_weights
)
.
sum
(
-
1
)
.
view
(
bs
,
num_heads
*
embed_dims
,
num_queries
)
)
return
output
.
transpose
(
1
,
2
).
contiguous
()
class
MultiScaleDeformableAttention
(
nn
.
Module
):
"""Multi-Scale Deformable Attention Module used in Deformable-DETR
`Deformable DETR: Deformable Transformers for End-to-End Object Detection.
<https://arxiv.org/pdf/2010.04159.pdf>`_.
Args:
embed_dim (int): The embedding dimension of Attention. Default: 256.
num_heads (int): The number of attention heads. Default: 8.
num_levels (int): The number of feature map used in Attention. Default: 4.
num_points (int): The number of sampling points for each query
in each head. Default: 4.
img2col_steps (int): The step used in image_to_column. Defualt: 64.
dropout (float): Dropout layer used in output. Default: 0.1.
batch_first (bool): if ``True``, then the input and output tensor will be
provided as `(bs, n, embed_dim)`. Default: False. `(n, bs, embed_dim)`
"""
def
__init__
(
self
,
embed_dim
:
int
=
256
,
num_heads
:
int
=
8
,
num_levels
:
int
=
4
,
num_points
:
int
=
4
,
img2col_step
:
int
=
64
,
batch_first
:
bool
=
False
,
):
super
().
__init__
()
if
embed_dim
%
num_heads
!=
0
:
raise
ValueError
(
"embed_dim must be divisible by num_heads, but got {} and {}"
.
format
(
embed_dim
,
num_heads
)
)
head_dim
=
embed_dim
//
num_heads
self
.
batch_first
=
batch_first
if
not
_is_power_of_2
(
head_dim
):
warnings
.
warn
(
"""
You'd better set d_model in MSDeformAttn to make sure that
each dim of the attention head a power of 2, which is more efficient.
"""
)
self
.
im2col_step
=
img2col_step
self
.
embed_dim
=
embed_dim
self
.
num_heads
=
num_heads
self
.
num_levels
=
num_levels
self
.
num_points
=
num_points
self
.
sampling_offsets
=
nn
.
Linear
(
embed_dim
,
num_heads
*
num_levels
*
num_points
*
2
)
self
.
attention_weights
=
nn
.
Linear
(
embed_dim
,
num_heads
*
num_levels
*
num_points
)
self
.
value_proj
=
nn
.
Linear
(
embed_dim
,
embed_dim
)
self
.
output_proj
=
nn
.
Linear
(
embed_dim
,
embed_dim
)
self
.
init_weights
()
def
_reset_parameters
(
self
):
return
self
.
init_weights
()
def
init_weights
(
self
):
"""
Default initialization for Parameters of Module.
"""
constant_
(
self
.
sampling_offsets
.
weight
.
data
,
0.0
)
thetas
=
torch
.
arange
(
self
.
num_heads
,
dtype
=
torch
.
float32
)
*
(
2.0
*
math
.
pi
/
self
.
num_heads
)
grid_init
=
torch
.
stack
([
thetas
.
cos
(),
thetas
.
sin
()],
-
1
)
grid_init
=
(
(
grid_init
/
grid_init
.
abs
().
max
(
-
1
,
keepdim
=
True
)[
0
])
.
view
(
self
.
num_heads
,
1
,
1
,
2
)
.
repeat
(
1
,
self
.
num_levels
,
self
.
num_points
,
1
)
)
for
i
in
range
(
self
.
num_points
):
grid_init
[:,
:,
i
,
:]
*=
i
+
1
with
torch
.
no_grad
():
self
.
sampling_offsets
.
bias
=
nn
.
Parameter
(
grid_init
.
view
(
-
1
))
constant_
(
self
.
attention_weights
.
weight
.
data
,
0.0
)
constant_
(
self
.
attention_weights
.
bias
.
data
,
0.0
)
xavier_uniform_
(
self
.
value_proj
.
weight
.
data
)
constant_
(
self
.
value_proj
.
bias
.
data
,
0.0
)
xavier_uniform_
(
self
.
output_proj
.
weight
.
data
)
constant_
(
self
.
output_proj
.
bias
.
data
,
0.0
)
def
freeze_sampling_offsets
(
self
):
print
(
"Freeze sampling offsets"
)
self
.
sampling_offsets
.
weight
.
requires_grad
=
False
self
.
sampling_offsets
.
bias
.
requires_grad
=
False
def
freeze_attention_weights
(
self
):
print
(
"Freeze attention weights"
)
self
.
attention_weights
.
weight
.
requires_grad
=
False
self
.
attention_weights
.
bias
.
requires_grad
=
False
def
forward
(
self
,
query
:
torch
.
Tensor
,
key
:
Optional
[
torch
.
Tensor
]
=
None
,
value
:
Optional
[
torch
.
Tensor
]
=
None
,
query_pos
:
Optional
[
torch
.
Tensor
]
=
None
,
key_padding_mask
:
Optional
[
torch
.
Tensor
]
=
None
,
reference_points
:
Optional
[
torch
.
Tensor
]
=
None
,
spatial_shapes
:
Optional
[
torch
.
Tensor
]
=
None
,
level_start_index
:
Optional
[
torch
.
Tensor
]
=
None
,
**
kwargs
)
->
torch
.
Tensor
:
"""Forward Function of MultiScaleDeformableAttention
Args:
query (torch.Tensor): Query embeddings with shape
`(num_query, bs, embed_dim)`
key (torch.Tensor): Key embeddings with shape
`(num_key, bs, embed_dim)`
value (torch.Tensor): Value embeddings with shape
`(num_key, bs, embed_dim)`
query_pos (torch.Tensor): The position embedding for `query`. Default: None.
key_padding_mask (torch.Tensor): ByteTensor for `query`, with shape `(bs, num_key)`,
indicating which elements within `key` to be ignored in attention.
reference_points (torch.Tensor): The normalized reference points
with shape `(bs, num_query, num_levels, 2)`,
all elements is range in [0, 1], top-left (0, 0),
bottom-right (1, 1), including padding are.
or `(N, Length_{query}, num_levels, 4)`, add additional
two dimensions `(h, w)` to form reference boxes.
spatial_shapes (torch.Tensor): Spatial shape of features in different levels.
With shape `(num_levels, 2)`, last dimension represents `(h, w)`.
level_start_index (torch.Tensor): The start index of each level. A tensor with
shape `(num_levels, )` which can be represented as
`[0, h_0 * w_0, h_0 * w_0 + h_1 * w_1, ...]`.
Returns:
torch.Tensor: forward results with shape `(num_query, bs, embed_dim)`
"""
if
value
is
None
:
value
=
query
if
query_pos
is
not
None
:
query
=
query
+
query_pos
if
not
self
.
batch_first
:
# change to (bs, num_query ,embed_dims)
query
=
query
.
permute
(
1
,
0
,
2
)
value
=
value
.
permute
(
1
,
0
,
2
)
bs
,
num_query
,
_
=
query
.
shape
bs
,
num_value
,
_
=
value
.
shape
assert
(
spatial_shapes
[:,
0
]
*
spatial_shapes
[:,
1
]).
sum
()
==
num_value
value
=
self
.
value_proj
(
value
)
if
key_padding_mask
is
not
None
:
value
=
value
.
masked_fill
(
key_padding_mask
[...,
None
],
float
(
0
))
value
=
value
.
view
(
bs
,
num_value
,
self
.
num_heads
,
-
1
)
sampling_offsets
=
self
.
sampling_offsets
(
query
).
view
(
bs
,
num_query
,
self
.
num_heads
,
self
.
num_levels
,
self
.
num_points
,
2
)
attention_weights
=
self
.
attention_weights
(
query
).
view
(
bs
,
num_query
,
self
.
num_heads
,
self
.
num_levels
*
self
.
num_points
)
attention_weights
=
attention_weights
.
softmax
(
-
1
)
attention_weights
=
attention_weights
.
view
(
bs
,
num_query
,
self
.
num_heads
,
self
.
num_levels
,
self
.
num_points
,
)
# bs, num_query, num_heads, num_levels, num_points, 2
if
reference_points
.
shape
[
-
1
]
==
2
:
offset_normalizer
=
torch
.
stack
([
spatial_shapes
[...,
1
],
spatial_shapes
[...,
0
]],
-
1
)
sampling_locations
=
(
reference_points
[:,
:,
None
,
:,
None
,
:]
+
sampling_offsets
/
offset_normalizer
[
None
,
None
,
None
,
:,
None
,
:]
)
elif
reference_points
.
shape
[
-
1
]
==
4
:
sampling_locations
=
(
reference_points
[:,
:,
None
,
:,
None
,
:
2
]
+
sampling_offsets
/
self
.
num_points
*
reference_points
[:,
:,
None
,
:,
None
,
2
:]
*
0.5
)
else
:
raise
ValueError
(
"Last dim of reference_points must be 2 or 4, but get {} instead."
.
format
(
reference_points
.
shape
[
-
1
]
)
)
if
torch
.
cuda
.
is_available
()
and
value
.
is_cuda
:
halffloat
=
False
if
value
.
dtype
==
torch
.
float16
:
halffloat
=
True
value
=
value
.
float
()
sampling_locations
=
sampling_locations
.
float
()
attention_weights
=
attention_weights
.
float
()
output
=
MultiScaleDeformableAttnFunction
.
apply
(
value
,
spatial_shapes
,
level_start_index
,
sampling_locations
,
attention_weights
,
self
.
im2col_step
,
)
if
halffloat
:
output
=
output
.
half
()
else
:
output
=
multi_scale_deformable_attn_pytorch
(
value
,
spatial_shapes
,
sampling_locations
,
attention_weights
)
output
=
self
.
output_proj
(
output
)
if
not
self
.
batch_first
:
output
=
output
.
permute
(
1
,
0
,
2
)
return
output
def
create_dummy_class
(
klass
,
dependency
,
message
=
""
):
"""
When a dependency of a class is not available, create a dummy class which throws ImportError
when used.
Args:
klass (str): name of the class.
dependency (str): name of the dependency.
message: extra message to print
Returns:
class: a class object
"""
err
=
"Cannot import '{}', therefore '{}' is not available."
.
format
(
dependency
,
klass
)
if
message
:
err
=
err
+
" "
+
message
class
_DummyMetaClass
(
type
):
# throw error on class attribute access
def
__getattr__
(
_
,
__
):
# noqa: B902
raise
ImportError
(
err
)
class
_Dummy
(
object
,
metaclass
=
_DummyMetaClass
):
# throw error on constructor
def
__init__
(
self
,
*
args
,
**
kwargs
):
raise
ImportError
(
err
)
return
_Dummy
def
create_dummy_func
(
func
,
dependency
,
message
=
""
):
"""
When a dependency of a function is not available, create a dummy function which throws
ImportError when used.
Args:
func (str): name of the function.
dependency (str or list[str]): name(s) of the dependency.
message: extra message to print
Returns:
function: a function object
"""
err
=
"Cannot import '{}', therefore '{}' is not available."
.
format
(
dependency
,
func
)
if
message
:
err
=
err
+
" "
+
message
if
isinstance
(
dependency
,
(
list
,
tuple
)):
dependency
=
","
.
join
(
dependency
)
def
_dummy
(
*
args
,
**
kwargs
):
raise
ImportError
(
err
)
return
_dummy
groundingdino/models/GroundingDINO/transformer copy.py
0 → 100644
View file @
34e4011b
# ------------------------------------------------------------------------
# Grounding DINO
# url: https://github.com/IDEA-Research/GroundingDINO
# Copyright (c) 2023 IDEA. All Rights Reserved.
# Licensed under the Apache License, Version 2.0 [see LICENSE for details]
# ------------------------------------------------------------------------
# DINO
# Copyright (c) 2022 IDEA. All Rights Reserved.
# Licensed under the Apache License, Version 2.0 [see LICENSE for details]
# ------------------------------------------------------------------------
# Conditional DETR Transformer class.
# Copyright (c) 2021 Microsoft. All Rights Reserved.
# Licensed under the Apache License, Version 2.0 [see LICENSE for details]
# ------------------------------------------------------------------------
# Modified from DETR (https://github.com/facebookresearch/detr)
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
# ------------------------------------------------------------------------
from
typing
import
Optional
import
torch
import
torch.utils.checkpoint
as
checkpoint
from
torch
import
Tensor
,
nn
from
groundingdino.util.misc
import
inverse_sigmoid
from
.fuse_modules
import
BiAttentionBlock
from
.ms_deform_attn
import
MultiScaleDeformableAttention
as
MSDeformAttn
from
.transformer_vanilla
import
TransformerEncoderLayer
from
.utils
import
(
MLP
,
_get_activation_fn
,
_get_clones
,
gen_encoder_output_proposals
,
gen_sineembed_for_position
,
get_sine_pos_embed
,
)
class
Transformer
(
nn
.
Module
):
def
__init__
(
self
,
d_model
=
256
,
nhead
=
8
,
num_queries
=
300
,
num_encoder_layers
=
6
,
num_unicoder_layers
=
0
,
num_decoder_layers
=
6
,
dim_feedforward
=
2048
,
dropout
=
0.0
,
activation
=
"relu"
,
normalize_before
=
False
,
return_intermediate_dec
=
False
,
query_dim
=
4
,
num_patterns
=
0
,
# for deformable encoder
num_feature_levels
=
1
,
enc_n_points
=
4
,
dec_n_points
=
4
,
# init query
learnable_tgt_init
=
False
,
# two stage
two_stage_type
=
"no"
,
# ['no', 'standard', 'early', 'combine', 'enceachlayer', 'enclayer1']
embed_init_tgt
=
False
,
# for text
use_text_enhancer
=
False
,
use_fusion_layer
=
False
,
use_checkpoint
=
False
,
use_transformer_ckpt
=
False
,
use_text_cross_attention
=
False
,
text_dropout
=
0.1
,
fusion_dropout
=
0.1
,
fusion_droppath
=
0.0
,
):
super
().
__init__
()
self
.
num_feature_levels
=
num_feature_levels
self
.
num_encoder_layers
=
num_encoder_layers
self
.
num_unicoder_layers
=
num_unicoder_layers
self
.
num_decoder_layers
=
num_decoder_layers
self
.
num_queries
=
num_queries
assert
query_dim
==
4
# choose encoder layer type
encoder_layer
=
DeformableTransformerEncoderLayer
(
d_model
,
dim_feedforward
,
dropout
,
activation
,
num_feature_levels
,
nhead
,
enc_n_points
)
if
use_text_enhancer
:
text_enhance_layer
=
TransformerEncoderLayer
(
d_model
=
d_model
,
nhead
=
nhead
//
2
,
dim_feedforward
=
dim_feedforward
//
2
,
dropout
=
text_dropout
,
)
else
:
text_enhance_layer
=
None
if
use_fusion_layer
:
feature_fusion_layer
=
BiAttentionBlock
(
v_dim
=
d_model
,
l_dim
=
d_model
,
embed_dim
=
dim_feedforward
//
2
,
num_heads
=
nhead
//
2
,
dropout
=
fusion_dropout
,
drop_path
=
fusion_droppath
,
)
else
:
feature_fusion_layer
=
None
encoder_norm
=
nn
.
LayerNorm
(
d_model
)
if
normalize_before
else
None
assert
encoder_norm
is
None
self
.
encoder
=
TransformerEncoder
(
encoder_layer
,
num_encoder_layers
,
d_model
=
d_model
,
num_queries
=
num_queries
,
text_enhance_layer
=
text_enhance_layer
,
feature_fusion_layer
=
feature_fusion_layer
,
use_checkpoint
=
use_checkpoint
,
use_transformer_ckpt
=
use_transformer_ckpt
,
)
# choose decoder layer type
decoder_layer
=
DeformableTransformerDecoderLayer
(
d_model
,
dim_feedforward
,
dropout
,
activation
,
num_feature_levels
,
nhead
,
dec_n_points
,
use_text_cross_attention
=
use_text_cross_attention
,
)
decoder_norm
=
nn
.
LayerNorm
(
d_model
)
self
.
decoder
=
TransformerDecoder
(
decoder_layer
,
num_decoder_layers
,
decoder_norm
,
return_intermediate
=
return_intermediate_dec
,
d_model
=
d_model
,
query_dim
=
query_dim
,
num_feature_levels
=
num_feature_levels
,
)
self
.
d_model
=
d_model
self
.
nhead
=
nhead
self
.
dec_layers
=
num_decoder_layers
self
.
num_queries
=
num_queries
# useful for single stage model only
self
.
num_patterns
=
num_patterns
if
not
isinstance
(
num_patterns
,
int
):
Warning
(
"num_patterns should be int but {}"
.
format
(
type
(
num_patterns
)))
self
.
num_patterns
=
0
if
num_feature_levels
>
1
:
if
self
.
num_encoder_layers
>
0
:
self
.
level_embed
=
nn
.
Parameter
(
torch
.
Tensor
(
num_feature_levels
,
d_model
))
else
:
self
.
level_embed
=
None
self
.
learnable_tgt_init
=
learnable_tgt_init
assert
learnable_tgt_init
,
"why not learnable_tgt_init"
self
.
embed_init_tgt
=
embed_init_tgt
if
(
two_stage_type
!=
"no"
and
embed_init_tgt
)
or
(
two_stage_type
==
"no"
):
self
.
tgt_embed
=
nn
.
Embedding
(
self
.
num_queries
,
d_model
)
nn
.
init
.
normal_
(
self
.
tgt_embed
.
weight
.
data
)
else
:
self
.
tgt_embed
=
None
# for two stage
self
.
two_stage_type
=
two_stage_type
assert
two_stage_type
in
[
"no"
,
"standard"
],
"unknown param {} of two_stage_type"
.
format
(
two_stage_type
)
if
two_stage_type
==
"standard"
:
# anchor selection at the output of encoder
self
.
enc_output
=
nn
.
Linear
(
d_model
,
d_model
)
self
.
enc_output_norm
=
nn
.
LayerNorm
(
d_model
)
self
.
two_stage_wh_embedding
=
None
if
two_stage_type
==
"no"
:
self
.
init_ref_points
(
num_queries
)
# init self.refpoint_embed
self
.
enc_out_class_embed
=
None
self
.
enc_out_bbox_embed
=
None
self
.
_reset_parameters
()
def
_reset_parameters
(
self
):
for
p
in
self
.
parameters
():
if
p
.
dim
()
>
1
:
nn
.
init
.
xavier_uniform_
(
p
)
for
m
in
self
.
modules
():
if
isinstance
(
m
,
MSDeformAttn
):
m
.
_reset_parameters
()
if
self
.
num_feature_levels
>
1
and
self
.
level_embed
is
not
None
:
nn
.
init
.
normal_
(
self
.
level_embed
)
def
get_valid_ratio
(
self
,
mask
):
_
,
H
,
W
=
mask
.
shape
valid_H
=
torch
.
sum
(
~
mask
[:,
:,
0
],
1
)
valid_W
=
torch
.
sum
(
~
mask
[:,
0
,
:],
1
)
valid_ratio_h
=
valid_H
.
float
()
/
H
valid_ratio_w
=
valid_W
.
float
()
/
W
valid_ratio
=
torch
.
stack
([
valid_ratio_w
,
valid_ratio_h
],
-
1
)
return
valid_ratio
def
init_ref_points
(
self
,
use_num_queries
):
self
.
refpoint_embed
=
nn
.
Embedding
(
use_num_queries
,
4
)
def
forward
(
self
,
srcs
,
masks
,
refpoint_embed
,
pos_embeds
,
tgt
,
attn_mask
=
None
,
text_dict
=
None
):
"""
Input:
- srcs: List of multi features [bs, ci, hi, wi]
- masks: List of multi masks [bs, hi, wi]
- refpoint_embed: [bs, num_dn, 4]. None in infer
- pos_embeds: List of multi pos embeds [bs, ci, hi, wi]
- tgt: [bs, num_dn, d_model]. None in infer
"""
# prepare input for encoder
src_flatten
=
[]
mask_flatten
=
[]
lvl_pos_embed_flatten
=
[]
spatial_shapes
=
[]
for
lvl
,
(
src
,
mask
,
pos_embed
)
in
enumerate
(
zip
(
srcs
,
masks
,
pos_embeds
)):
bs
,
c
,
h
,
w
=
src
.
shape
spatial_shape
=
(
h
,
w
)
spatial_shapes
.
append
(
spatial_shape
)
src
=
src
.
flatten
(
2
).
transpose
(
1
,
2
)
# bs, hw, c
mask
=
mask
.
flatten
(
1
)
# bs, hw
pos_embed
=
pos_embed
.
flatten
(
2
).
transpose
(
1
,
2
)
# bs, hw, c
if
self
.
num_feature_levels
>
1
and
self
.
level_embed
is
not
None
:
lvl_pos_embed
=
pos_embed
+
self
.
level_embed
[
lvl
].
view
(
1
,
1
,
-
1
)
else
:
lvl_pos_embed
=
pos_embed
lvl_pos_embed_flatten
.
append
(
lvl_pos_embed
)
src_flatten
.
append
(
src
)
mask_flatten
.
append
(
mask
)
src_flatten
=
torch
.
cat
(
src_flatten
,
1
)
# bs, \sum{hxw}, c
mask_flatten
=
torch
.
cat
(
mask_flatten
,
1
)
# bs, \sum{hxw}
lvl_pos_embed_flatten
=
torch
.
cat
(
lvl_pos_embed_flatten
,
1
)
# bs, \sum{hxw}, c
spatial_shapes
=
torch
.
as_tensor
(
spatial_shapes
,
dtype
=
torch
.
long
,
device
=
src_flatten
.
device
)
level_start_index
=
torch
.
cat
(
(
spatial_shapes
.
new_zeros
((
1
,)),
spatial_shapes
.
prod
(
1
).
cumsum
(
0
)[:
-
1
])
)
valid_ratios
=
torch
.
stack
([
self
.
get_valid_ratio
(
m
)
for
m
in
masks
],
1
)
# two stage
enc_topk_proposals
=
enc_refpoint_embed
=
None
#########################################################
# Begin Encoder
#########################################################
memory
,
memory_text
=
self
.
encoder
(
src_flatten
,
pos
=
lvl_pos_embed_flatten
,
level_start_index
=
level_start_index
,
spatial_shapes
=
spatial_shapes
,
valid_ratios
=
valid_ratios
,
key_padding_mask
=
mask_flatten
,
memory_text
=
text_dict
[
"encoded_text"
],
text_attention_mask
=~
text_dict
[
"text_token_mask"
],
# we ~ the mask . False means use the token; True means pad the token
position_ids
=
text_dict
[
"position_ids"
],
text_self_attention_masks
=
text_dict
[
"text_self_attention_masks"
],
)
#########################################################
# End Encoder
# - memory: bs, \sum{hw}, c
# - mask_flatten: bs, \sum{hw}
# - lvl_pos_embed_flatten: bs, \sum{hw}, c
# - enc_intermediate_output: None or (nenc+1, bs, nq, c) or (nenc, bs, nq, c)
# - enc_intermediate_refpoints: None or (nenc+1, bs, nq, c) or (nenc, bs, nq, c)
#########################################################
text_dict
[
"encoded_text"
]
=
memory_text
# if os.environ.get("SHILONG_AMP_INFNAN_DEBUG") == '1':
# if memory.isnan().any() | memory.isinf().any():
# import ipdb; ipdb.set_trace()
if
self
.
two_stage_type
==
"standard"
:
output_memory
,
output_proposals
=
gen_encoder_output_proposals
(
memory
,
mask_flatten
,
spatial_shapes
)
output_memory
=
self
.
enc_output_norm
(
self
.
enc_output
(
output_memory
))
if
text_dict
is
not
None
:
enc_outputs_class_unselected
=
self
.
enc_out_class_embed
(
output_memory
,
text_dict
)
else
:
enc_outputs_class_unselected
=
self
.
enc_out_class_embed
(
output_memory
)
topk_logits
=
enc_outputs_class_unselected
.
max
(
-
1
)[
0
]
enc_outputs_coord_unselected
=
(
self
.
enc_out_bbox_embed
(
output_memory
)
+
output_proposals
)
# (bs, \sum{hw}, 4) unsigmoid
topk
=
self
.
num_queries
topk_proposals
=
torch
.
topk
(
topk_logits
,
topk
,
dim
=
1
)[
1
]
# bs, nq
# gather boxes
refpoint_embed_undetach
=
torch
.
gather
(
enc_outputs_coord_unselected
,
1
,
topk_proposals
.
unsqueeze
(
-
1
).
repeat
(
1
,
1
,
4
)
)
# unsigmoid
refpoint_embed_
=
refpoint_embed_undetach
.
detach
()
init_box_proposal
=
torch
.
gather
(
output_proposals
,
1
,
topk_proposals
.
unsqueeze
(
-
1
).
repeat
(
1
,
1
,
4
)
).
sigmoid
()
# sigmoid
# gather tgt
tgt_undetach
=
torch
.
gather
(
output_memory
,
1
,
topk_proposals
.
unsqueeze
(
-
1
).
repeat
(
1
,
1
,
self
.
d_model
)
)
if
self
.
embed_init_tgt
:
tgt_
=
(
self
.
tgt_embed
.
weight
[:,
None
,
:].
repeat
(
1
,
bs
,
1
).
transpose
(
0
,
1
)
)
# nq, bs, d_model
else
:
tgt_
=
tgt_undetach
.
detach
()
if
refpoint_embed
is
not
None
:
refpoint_embed
=
torch
.
cat
([
refpoint_embed
,
refpoint_embed_
],
dim
=
1
)
tgt
=
torch
.
cat
([
tgt
,
tgt_
],
dim
=
1
)
else
:
refpoint_embed
,
tgt
=
refpoint_embed_
,
tgt_
elif
self
.
two_stage_type
==
"no"
:
tgt_
=
(
self
.
tgt_embed
.
weight
[:,
None
,
:].
repeat
(
1
,
bs
,
1
).
transpose
(
0
,
1
)
)
# nq, bs, d_model
refpoint_embed_
=
(
self
.
refpoint_embed
.
weight
[:,
None
,
:].
repeat
(
1
,
bs
,
1
).
transpose
(
0
,
1
)
)
# nq, bs, 4
if
refpoint_embed
is
not
None
:
refpoint_embed
=
torch
.
cat
([
refpoint_embed
,
refpoint_embed_
],
dim
=
1
)
tgt
=
torch
.
cat
([
tgt
,
tgt_
],
dim
=
1
)
else
:
refpoint_embed
,
tgt
=
refpoint_embed_
,
tgt_
if
self
.
num_patterns
>
0
:
tgt_embed
=
tgt
.
repeat
(
1
,
self
.
num_patterns
,
1
)
refpoint_embed
=
refpoint_embed
.
repeat
(
1
,
self
.
num_patterns
,
1
)
tgt_pat
=
self
.
patterns
.
weight
[
None
,
:,
:].
repeat_interleave
(
self
.
num_queries
,
1
)
# 1, n_q*n_pat, d_model
tgt
=
tgt_embed
+
tgt_pat
init_box_proposal
=
refpoint_embed_
.
sigmoid
()
else
:
raise
NotImplementedError
(
"unknown two_stage_type {}"
.
format
(
self
.
two_stage_type
))
#########################################################
# End preparing tgt
# - tgt: bs, NQ, d_model
# - refpoint_embed(unsigmoid): bs, NQ, d_model
#########################################################
#########################################################
# Begin Decoder
#########################################################
hs
,
references
=
self
.
decoder
(
tgt
=
tgt
.
transpose
(
0
,
1
),
memory
=
memory
.
transpose
(
0
,
1
),
memory_key_padding_mask
=
mask_flatten
,
pos
=
lvl_pos_embed_flatten
.
transpose
(
0
,
1
),
refpoints_unsigmoid
=
refpoint_embed
.
transpose
(
0
,
1
),
level_start_index
=
level_start_index
,
spatial_shapes
=
spatial_shapes
,
valid_ratios
=
valid_ratios
,
tgt_mask
=
attn_mask
,
memory_text
=
text_dict
[
"encoded_text"
],
text_attention_mask
=~
text_dict
[
"text_token_mask"
],
# we ~ the mask . False means use the token; True means pad the token
)
#########################################################
# End Decoder
# hs: n_dec, bs, nq, d_model
# references: n_dec+1, bs, nq, query_dim
#########################################################
#########################################################
# Begin postprocess
#########################################################
if
self
.
two_stage_type
==
"standard"
:
hs_enc
=
tgt_undetach
.
unsqueeze
(
0
)
ref_enc
=
refpoint_embed_undetach
.
sigmoid
().
unsqueeze
(
0
)
else
:
hs_enc
=
ref_enc
=
None
#########################################################
# End postprocess
# hs_enc: (n_enc+1, bs, nq, d_model) or (1, bs, nq, d_model) or (n_enc, bs, nq, d_model) or None
# ref_enc: (n_enc+1, bs, nq, query_dim) or (1, bs, nq, query_dim) or (n_enc, bs, nq, d_model) or None
#########################################################
return
hs
,
references
,
hs_enc
,
ref_enc
,
init_box_proposal
# hs: (n_dec, bs, nq, d_model)
# references: sigmoid coordinates. (n_dec+1, bs, bq, 4)
# hs_enc: (n_enc+1, bs, nq, d_model) or (1, bs, nq, d_model) or None
# ref_enc: sigmoid coordinates. \
# (n_enc+1, bs, nq, query_dim) or (1, bs, nq, query_dim) or None
class
TransformerEncoder
(
nn
.
Module
):
def
__init__
(
self
,
encoder_layer
,
num_layers
,
d_model
=
256
,
num_queries
=
300
,
enc_layer_share
=
False
,
text_enhance_layer
=
None
,
feature_fusion_layer
=
None
,
use_checkpoint
=
False
,
use_transformer_ckpt
=
False
,
):
"""_summary_
Args:
encoder_layer (_type_): _description_
num_layers (_type_): _description_
norm (_type_, optional): _description_. Defaults to None.
d_model (int, optional): _description_. Defaults to 256.
num_queries (int, optional): _description_. Defaults to 300.
enc_layer_share (bool, optional): _description_. Defaults to False.
"""
super
().
__init__
()
# prepare layers
self
.
layers
=
[]
self
.
text_layers
=
[]
self
.
fusion_layers
=
[]
if
num_layers
>
0
:
self
.
layers
=
_get_clones
(
encoder_layer
,
num_layers
,
layer_share
=
enc_layer_share
)
if
text_enhance_layer
is
not
None
:
self
.
text_layers
=
_get_clones
(
text_enhance_layer
,
num_layers
,
layer_share
=
enc_layer_share
)
if
feature_fusion_layer
is
not
None
:
self
.
fusion_layers
=
_get_clones
(
feature_fusion_layer
,
num_layers
,
layer_share
=
enc_layer_share
)
else
:
self
.
layers
=
[]
del
encoder_layer
if
text_enhance_layer
is
not
None
:
self
.
text_layers
=
[]
del
text_enhance_layer
if
feature_fusion_layer
is
not
None
:
self
.
fusion_layers
=
[]
del
feature_fusion_layer
self
.
query_scale
=
None
self
.
num_queries
=
num_queries
self
.
num_layers
=
num_layers
self
.
d_model
=
d_model
self
.
use_checkpoint
=
use_checkpoint
self
.
use_transformer_ckpt
=
use_transformer_ckpt
@
staticmethod
def
get_reference_points
(
spatial_shapes
,
valid_ratios
,
device
):
reference_points_list
=
[]
for
lvl
,
(
H_
,
W_
)
in
enumerate
(
spatial_shapes
):
ref_y
,
ref_x
=
torch
.
meshgrid
(
torch
.
linspace
(
0.5
,
H_
-
0.5
,
H_
,
dtype
=
torch
.
float32
,
device
=
device
),
torch
.
linspace
(
0.5
,
W_
-
0.5
,
W_
,
dtype
=
torch
.
float32
,
device
=
device
),
)
ref_y
=
ref_y
.
reshape
(
-
1
)[
None
]
/
(
valid_ratios
[:,
None
,
lvl
,
1
]
*
H_
)
ref_x
=
ref_x
.
reshape
(
-
1
)[
None
]
/
(
valid_ratios
[:,
None
,
lvl
,
0
]
*
W_
)
ref
=
torch
.
stack
((
ref_x
,
ref_y
),
-
1
)
reference_points_list
.
append
(
ref
)
reference_points
=
torch
.
cat
(
reference_points_list
,
1
)
reference_points
=
reference_points
[:,
:,
None
]
*
valid_ratios
[:,
None
]
return
reference_points
def
forward
(
self
,
# for images
src
:
Tensor
,
pos
:
Tensor
,
spatial_shapes
:
Tensor
,
level_start_index
:
Tensor
,
valid_ratios
:
Tensor
,
key_padding_mask
:
Tensor
,
# for texts
memory_text
:
Tensor
=
None
,
text_attention_mask
:
Tensor
=
None
,
pos_text
:
Tensor
=
None
,
text_self_attention_masks
:
Tensor
=
None
,
position_ids
:
Tensor
=
None
,
):
"""
Input:
- src: [bs, sum(hi*wi), 256]
- pos: pos embed for src. [bs, sum(hi*wi), 256]
- spatial_shapes: h,w of each level [num_level, 2]
- level_start_index: [num_level] start point of level in sum(hi*wi).
- valid_ratios: [bs, num_level, 2]
- key_padding_mask: [bs, sum(hi*wi)]
- memory_text: bs, n_text, 256
- text_attention_mask: bs, n_text
False for no padding; True for padding
- pos_text: bs, n_text, 256
- position_ids: bs, n_text
Intermedia:
- reference_points: [bs, sum(hi*wi), num_level, 2]
Outpus:
- output: [bs, sum(hi*wi), 256]
"""
output
=
src
# preparation and reshape
if
self
.
num_layers
>
0
:
reference_points
=
self
.
get_reference_points
(
spatial_shapes
,
valid_ratios
,
device
=
src
.
device
)
if
self
.
text_layers
:
# generate pos_text
bs
,
n_text
,
text_dim
=
memory_text
.
shape
if
pos_text
is
None
and
position_ids
is
None
:
pos_text
=
(
torch
.
arange
(
n_text
,
device
=
memory_text
.
device
)
.
float
()
.
unsqueeze
(
0
)
.
unsqueeze
(
-
1
)
.
repeat
(
bs
,
1
,
1
)
)
pos_text
=
get_sine_pos_embed
(
pos_text
,
num_pos_feats
=
256
,
exchange_xy
=
False
)
if
position_ids
is
not
None
:
pos_text
=
get_sine_pos_embed
(
position_ids
[...,
None
],
num_pos_feats
=
256
,
exchange_xy
=
False
)
# main process
for
layer_id
,
layer
in
enumerate
(
self
.
layers
):
# if output.isnan().any() or memory_text.isnan().any():
# if os.environ.get('IPDB_SHILONG_DEBUG', None) == 'INFO':
# import ipdb; ipdb.set_trace()
if
self
.
fusion_layers
:
if
self
.
use_checkpoint
:
output
,
memory_text
=
checkpoint
.
checkpoint
(
self
.
fusion_layers
[
layer_id
],
output
,
memory_text
,
key_padding_mask
,
text_attention_mask
,
)
else
:
output
,
memory_text
=
self
.
fusion_layers
[
layer_id
](
v
=
output
,
l
=
memory_text
,
attention_mask_v
=
key_padding_mask
,
attention_mask_l
=
text_attention_mask
,
)
if
self
.
text_layers
:
memory_text
=
self
.
text_layers
[
layer_id
](
src
=
memory_text
.
transpose
(
0
,
1
),
src_mask
=~
text_self_attention_masks
,
# note we use ~ for mask here
src_key_padding_mask
=
text_attention_mask
,
pos
=
(
pos_text
.
transpose
(
0
,
1
)
if
pos_text
is
not
None
else
None
),
).
transpose
(
0
,
1
)
# main process
if
self
.
use_transformer_ckpt
:
output
=
checkpoint
.
checkpoint
(
layer
,
output
,
pos
,
reference_points
,
spatial_shapes
,
level_start_index
,
key_padding_mask
,
)
else
:
output
=
layer
(
src
=
output
,
pos
=
pos
,
reference_points
=
reference_points
,
spatial_shapes
=
spatial_shapes
,
level_start_index
=
level_start_index
,
key_padding_mask
=
key_padding_mask
,
)
return
output
,
memory_text
class
TransformerDecoder
(
nn
.
Module
):
def
__init__
(
self
,
decoder_layer
,
num_layers
,
norm
=
None
,
return_intermediate
=
False
,
d_model
=
256
,
query_dim
=
4
,
num_feature_levels
=
1
,
):
super
().
__init__
()
if
num_layers
>
0
:
self
.
layers
=
_get_clones
(
decoder_layer
,
num_layers
)
else
:
self
.
layers
=
[]
self
.
num_layers
=
num_layers
self
.
norm
=
norm
self
.
return_intermediate
=
return_intermediate
assert
return_intermediate
,
"support return_intermediate only"
self
.
query_dim
=
query_dim
assert
query_dim
in
[
2
,
4
],
"query_dim should be 2/4 but {}"
.
format
(
query_dim
)
self
.
num_feature_levels
=
num_feature_levels
self
.
ref_point_head
=
MLP
(
query_dim
//
2
*
d_model
,
d_model
,
d_model
,
2
)
self
.
query_pos_sine_scale
=
None
self
.
query_scale
=
None
self
.
bbox_embed
=
None
self
.
class_embed
=
None
self
.
d_model
=
d_model
self
.
ref_anchor_head
=
None
def
forward
(
self
,
tgt
,
memory
,
tgt_mask
:
Optional
[
Tensor
]
=
None
,
memory_mask
:
Optional
[
Tensor
]
=
None
,
tgt_key_padding_mask
:
Optional
[
Tensor
]
=
None
,
memory_key_padding_mask
:
Optional
[
Tensor
]
=
None
,
pos
:
Optional
[
Tensor
]
=
None
,
refpoints_unsigmoid
:
Optional
[
Tensor
]
=
None
,
# num_queries, bs, 2
# for memory
level_start_index
:
Optional
[
Tensor
]
=
None
,
# num_levels
spatial_shapes
:
Optional
[
Tensor
]
=
None
,
# bs, num_levels, 2
valid_ratios
:
Optional
[
Tensor
]
=
None
,
# for text
memory_text
:
Optional
[
Tensor
]
=
None
,
text_attention_mask
:
Optional
[
Tensor
]
=
None
,
):
"""
Input:
- tgt: nq, bs, d_model
- memory: hw, bs, d_model
- pos: hw, bs, d_model
- refpoints_unsigmoid: nq, bs, 2/4
- valid_ratios/spatial_shapes: bs, nlevel, 2
"""
output
=
tgt
intermediate
=
[]
reference_points
=
refpoints_unsigmoid
.
sigmoid
()
ref_points
=
[
reference_points
]
for
layer_id
,
layer
in
enumerate
(
self
.
layers
):
if
reference_points
.
shape
[
-
1
]
==
4
:
reference_points_input
=
(
reference_points
[:,
:,
None
]
*
torch
.
cat
([
valid_ratios
,
valid_ratios
],
-
1
)[
None
,
:]
)
# nq, bs, nlevel, 4
else
:
assert
reference_points
.
shape
[
-
1
]
==
2
reference_points_input
=
reference_points
[:,
:,
None
]
*
valid_ratios
[
None
,
:]
query_sine_embed
=
gen_sineembed_for_position
(
reference_points_input
[:,
:,
0
,
:]
)
# nq, bs, 256*2
# conditional query
raw_query_pos
=
self
.
ref_point_head
(
query_sine_embed
)
# nq, bs, 256
pos_scale
=
self
.
query_scale
(
output
)
if
self
.
query_scale
is
not
None
else
1
query_pos
=
pos_scale
*
raw_query_pos
# if os.environ.get("SHILONG_AMP_INFNAN_DEBUG") == '1':
# if query_pos.isnan().any() | query_pos.isinf().any():
# import ipdb; ipdb.set_trace()
# main process
output
=
layer
(
tgt
=
output
,
tgt_query_pos
=
query_pos
,
tgt_query_sine_embed
=
query_sine_embed
,
tgt_key_padding_mask
=
tgt_key_padding_mask
,
tgt_reference_points
=
reference_points_input
,
memory_text
=
memory_text
,
text_attention_mask
=
text_attention_mask
,
memory
=
memory
,
memory_key_padding_mask
=
memory_key_padding_mask
,
memory_level_start_index
=
level_start_index
,
memory_spatial_shapes
=
spatial_shapes
,
memory_pos
=
pos
,
self_attn_mask
=
tgt_mask
,
cross_attn_mask
=
memory_mask
,
)
if
output
.
isnan
().
any
()
|
output
.
isinf
().
any
():
print
(
f
"output layer_id
{
layer_id
}
is nan"
)
try
:
num_nan
=
output
.
isnan
().
sum
().
item
()
num_inf
=
output
.
isinf
().
sum
().
item
()
print
(
f
"num_nan
{
num_nan
}
, num_inf
{
num_inf
}
"
)
except
Exception
as
e
:
print
(
e
)
# if os.environ.get("SHILONG_AMP_INFNAN_DEBUG") == '1':
# import ipdb; ipdb.set_trace()
# iter update
if
self
.
bbox_embed
is
not
None
:
# box_holder = self.bbox_embed(output)
# box_holder[..., :self.query_dim] += inverse_sigmoid(reference_points)
# new_reference_points = box_holder[..., :self.query_dim].sigmoid()
reference_before_sigmoid
=
inverse_sigmoid
(
reference_points
)
delta_unsig
=
self
.
bbox_embed
[
layer_id
](
output
)
outputs_unsig
=
delta_unsig
+
reference_before_sigmoid
new_reference_points
=
outputs_unsig
.
sigmoid
()
reference_points
=
new_reference_points
.
detach
()
# if layer_id != self.num_layers - 1:
ref_points
.
append
(
new_reference_points
)
intermediate
.
append
(
self
.
norm
(
output
))
return
[
[
itm_out
.
transpose
(
0
,
1
)
for
itm_out
in
intermediate
],
[
itm_refpoint
.
transpose
(
0
,
1
)
for
itm_refpoint
in
ref_points
],
]
class
DeformableTransformerEncoderLayer
(
nn
.
Module
):
def
__init__
(
self
,
d_model
=
256
,
d_ffn
=
1024
,
dropout
=
0.1
,
activation
=
"relu"
,
n_levels
=
4
,
n_heads
=
8
,
n_points
=
4
,
):
super
().
__init__
()
# self attention
self
.
self_attn
=
MSDeformAttn
(
embed_dim
=
d_model
,
num_levels
=
n_levels
,
num_heads
=
n_heads
,
num_points
=
n_points
,
batch_first
=
True
,
)
self
.
dropout1
=
nn
.
Dropout
(
dropout
)
self
.
norm1
=
nn
.
LayerNorm
(
d_model
)
# ffn
self
.
linear1
=
nn
.
Linear
(
d_model
,
d_ffn
)
self
.
activation
=
_get_activation_fn
(
activation
,
d_model
=
d_ffn
)
self
.
dropout2
=
nn
.
Dropout
(
dropout
)
self
.
linear2
=
nn
.
Linear
(
d_ffn
,
d_model
)
self
.
dropout3
=
nn
.
Dropout
(
dropout
)
self
.
norm2
=
nn
.
LayerNorm
(
d_model
)
@
staticmethod
def
with_pos_embed
(
tensor
,
pos
):
return
tensor
if
pos
is
None
else
tensor
+
pos
def
forward_ffn
(
self
,
src
):
src2
=
self
.
linear2
(
self
.
dropout2
(
self
.
activation
(
self
.
linear1
(
src
))))
src
=
src
+
self
.
dropout3
(
src2
)
src
=
self
.
norm2
(
src
)
return
src
def
forward
(
self
,
src
,
pos
,
reference_points
,
spatial_shapes
,
level_start_index
,
key_padding_mask
=
None
):
# self attention
# import ipdb; ipdb.set_trace()
src2
=
self
.
self_attn
(
query
=
self
.
with_pos_embed
(
src
,
pos
),
reference_points
=
reference_points
,
value
=
src
,
spatial_shapes
=
spatial_shapes
,
level_start_index
=
level_start_index
,
key_padding_mask
=
key_padding_mask
,
)
src
=
src
+
self
.
dropout1
(
src2
)
src
=
self
.
norm1
(
src
)
# ffn
src
=
self
.
forward_ffn
(
src
)
return
src
class
DeformableTransformerDecoderLayer
(
nn
.
Module
):
def
__init__
(
self
,
d_model
=
256
,
d_ffn
=
1024
,
dropout
=
0.1
,
activation
=
"relu"
,
n_levels
=
4
,
n_heads
=
8
,
n_points
=
4
,
use_text_feat_guide
=
False
,
use_text_cross_attention
=
False
,
):
super
().
__init__
()
# cross attention
self
.
cross_attn
=
MSDeformAttn
(
embed_dim
=
d_model
,
num_levels
=
n_levels
,
num_heads
=
n_heads
,
num_points
=
n_points
,
batch_first
=
True
,
)
self
.
dropout1
=
nn
.
Dropout
(
dropout
)
if
dropout
>
0
else
nn
.
Identity
()
self
.
norm1
=
nn
.
LayerNorm
(
d_model
)
# cross attention text
if
use_text_cross_attention
:
self
.
ca_text
=
nn
.
MultiheadAttention
(
d_model
,
n_heads
,
dropout
=
dropout
)
self
.
catext_dropout
=
nn
.
Dropout
(
dropout
)
if
dropout
>
0
else
nn
.
Identity
()
self
.
catext_norm
=
nn
.
LayerNorm
(
d_model
)
# self attention
self
.
self_attn
=
nn
.
MultiheadAttention
(
d_model
,
n_heads
,
dropout
=
dropout
)
self
.
dropout2
=
nn
.
Dropout
(
dropout
)
if
dropout
>
0
else
nn
.
Identity
()
self
.
norm2
=
nn
.
LayerNorm
(
d_model
)
# ffn
self
.
linear1
=
nn
.
Linear
(
d_model
,
d_ffn
)
self
.
activation
=
_get_activation_fn
(
activation
,
d_model
=
d_ffn
,
batch_dim
=
1
)
self
.
dropout3
=
nn
.
Dropout
(
dropout
)
if
dropout
>
0
else
nn
.
Identity
()
self
.
linear2
=
nn
.
Linear
(
d_ffn
,
d_model
)
self
.
dropout4
=
nn
.
Dropout
(
dropout
)
if
dropout
>
0
else
nn
.
Identity
()
self
.
norm3
=
nn
.
LayerNorm
(
d_model
)
self
.
key_aware_proj
=
None
self
.
use_text_feat_guide
=
use_text_feat_guide
assert
not
use_text_feat_guide
self
.
use_text_cross_attention
=
use_text_cross_attention
def
rm_self_attn_modules
(
self
):
self
.
self_attn
=
None
self
.
dropout2
=
None
self
.
norm2
=
None
@
staticmethod
def
with_pos_embed
(
tensor
,
pos
):
return
tensor
if
pos
is
None
else
tensor
+
pos
def
forward_ffn
(
self
,
tgt
):
with
torch
.
cuda
.
amp
.
autocast
(
enabled
=
False
):
tgt2
=
self
.
linear2
(
self
.
dropout3
(
self
.
activation
(
self
.
linear1
(
tgt
))))
tgt
=
tgt
+
self
.
dropout4
(
tgt2
)
tgt
=
self
.
norm3
(
tgt
)
return
tgt
def
forward
(
self
,
# for tgt
tgt
:
Optional
[
Tensor
],
# nq, bs, d_model
tgt_query_pos
:
Optional
[
Tensor
]
=
None
,
# pos for query. MLP(Sine(pos))
tgt_query_sine_embed
:
Optional
[
Tensor
]
=
None
,
# pos for query. Sine(pos)
tgt_key_padding_mask
:
Optional
[
Tensor
]
=
None
,
tgt_reference_points
:
Optional
[
Tensor
]
=
None
,
# nq, bs, 4
memory_text
:
Optional
[
Tensor
]
=
None
,
# bs, num_token, d_model
text_attention_mask
:
Optional
[
Tensor
]
=
None
,
# bs, num_token
# for memory
memory
:
Optional
[
Tensor
]
=
None
,
# hw, bs, d_model
memory_key_padding_mask
:
Optional
[
Tensor
]
=
None
,
memory_level_start_index
:
Optional
[
Tensor
]
=
None
,
# num_levels
memory_spatial_shapes
:
Optional
[
Tensor
]
=
None
,
# bs, num_levels, 2
memory_pos
:
Optional
[
Tensor
]
=
None
,
# pos for memory
# sa
self_attn_mask
:
Optional
[
Tensor
]
=
None
,
# mask used for self-attention
cross_attn_mask
:
Optional
[
Tensor
]
=
None
,
# mask used for cross-attention
):
"""
Input:
- tgt/tgt_query_pos: nq, bs, d_model
-
"""
assert
cross_attn_mask
is
None
# self attention
if
self
.
self_attn
is
not
None
:
# import ipdb; ipdb.set_trace()
q
=
k
=
self
.
with_pos_embed
(
tgt
,
tgt_query_pos
)
tgt2
=
self
.
self_attn
(
q
,
k
,
tgt
,
attn_mask
=
self_attn_mask
)[
0
]
tgt
=
tgt
+
self
.
dropout2
(
tgt2
)
tgt
=
self
.
norm2
(
tgt
)
if
self
.
use_text_cross_attention
:
tgt2
=
self
.
ca_text
(
self
.
with_pos_embed
(
tgt
,
tgt_query_pos
),
memory_text
.
transpose
(
0
,
1
),
memory_text
.
transpose
(
0
,
1
),
key_padding_mask
=
text_attention_mask
,
)[
0
]
tgt
=
tgt
+
self
.
catext_dropout
(
tgt2
)
tgt
=
self
.
catext_norm
(
tgt
)
tgt2
=
self
.
cross_attn
(
query
=
self
.
with_pos_embed
(
tgt
,
tgt_query_pos
).
transpose
(
0
,
1
),
reference_points
=
tgt_reference_points
.
transpose
(
0
,
1
).
contiguous
(),
value
=
memory
.
transpose
(
0
,
1
),
spatial_shapes
=
memory_spatial_shapes
,
level_start_index
=
memory_level_start_index
,
key_padding_mask
=
memory_key_padding_mask
,
).
transpose
(
0
,
1
)
tgt
=
tgt
+
self
.
dropout1
(
tgt2
)
tgt
=
self
.
norm1
(
tgt
)
# ffn
tgt
=
self
.
forward_ffn
(
tgt
)
return
tgt
def
build_transformer
(
args
):
return
Transformer
(
d_model
=
args
.
hidden_dim
,
dropout
=
args
.
dropout
,
nhead
=
args
.
nheads
,
num_queries
=
args
.
num_queries
,
dim_feedforward
=
args
.
dim_feedforward
,
num_encoder_layers
=
args
.
enc_layers
,
num_decoder_layers
=
args
.
dec_layers
,
normalize_before
=
args
.
pre_norm
,
return_intermediate_dec
=
True
,
query_dim
=
args
.
query_dim
,
activation
=
args
.
transformer_activation
,
num_patterns
=
args
.
num_patterns
,
num_feature_levels
=
args
.
num_feature_levels
,
enc_n_points
=
args
.
enc_n_points
,
dec_n_points
=
args
.
dec_n_points
,
learnable_tgt_init
=
True
,
# two stage
two_stage_type
=
args
.
two_stage_type
,
# ['no', 'standard', 'early']
embed_init_tgt
=
args
.
embed_init_tgt
,
use_text_enhancer
=
args
.
use_text_enhancer
,
use_fusion_layer
=
args
.
use_fusion_layer
,
use_checkpoint
=
args
.
use_checkpoint
,
use_transformer_ckpt
=
args
.
use_transformer_ckpt
,
use_text_cross_attention
=
args
.
use_text_cross_attention
,
text_dropout
=
args
.
text_dropout
,
fusion_dropout
=
args
.
fusion_dropout
,
fusion_droppath
=
args
.
fusion_droppath
,
)
groundingdino/models/GroundingDINO/transformer.py
0 → 100644
View file @
34e4011b
# ------------------------------------------------------------------------
# Grounding DINO
# url: https://github.com/IDEA-Research/GroundingDINO
# Copyright (c) 2023 IDEA. All Rights Reserved.
# Licensed under the Apache License, Version 2.0 [see LICENSE for details]
# ------------------------------------------------------------------------
# DINO
# Copyright (c) 2022 IDEA. All Rights Reserved.
# Licensed under the Apache License, Version 2.0 [see LICENSE for details]
# ------------------------------------------------------------------------
# Conditional DETR Transformer class.
# Copyright (c) 2021 Microsoft. All Rights Reserved.
# Licensed under the Apache License, Version 2.0 [see LICENSE for details]
# ------------------------------------------------------------------------
# Modified from DETR (https://github.com/facebookresearch/detr)
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
# ------------------------------------------------------------------------
from
typing
import
Optional
import
torch
import
torch.utils.checkpoint
as
checkpoint
from
torch
import
Tensor
,
nn
from
groundingdino.util.misc
import
inverse_sigmoid
from
.fuse_modules
import
BiAttentionBlock
from
.ms_deform_attn
import
MultiScaleDeformableAttention
as
MSDeformAttn
from
.transformer_vanilla
import
TransformerEncoderLayer
from
.utils
import
(
MLP
,
_get_activation_fn
,
_get_clones
,
gen_encoder_output_proposals
,
gen_sineembed_for_position
,
get_sine_pos_embed
,
)
class
Transformer
(
nn
.
Module
):
def
__init__
(
self
,
d_model
=
256
,
nhead
=
8
,
num_queries
=
300
,
num_encoder_layers
=
6
,
num_unicoder_layers
=
0
,
num_decoder_layers
=
6
,
dim_feedforward
=
2048
,
dropout
=
0.0
,
activation
=
"relu"
,
normalize_before
=
False
,
return_intermediate_dec
=
False
,
query_dim
=
4
,
num_patterns
=
0
,
# for deformable encoder
num_feature_levels
=
1
,
enc_n_points
=
4
,
dec_n_points
=
4
,
# init query
learnable_tgt_init
=
False
,
# two stage
two_stage_type
=
"no"
,
# ['no', 'standard', 'early', 'combine', 'enceachlayer', 'enclayer1']
embed_init_tgt
=
False
,
# for text
use_text_enhancer
=
False
,
use_fusion_layer
=
False
,
use_checkpoint
=
False
,
use_transformer_ckpt
=
False
,
use_text_cross_attention
=
False
,
text_dropout
=
0.1
,
fusion_dropout
=
0.1
,
fusion_droppath
=
0.0
,
):
super
().
__init__
()
self
.
num_feature_levels
=
num_feature_levels
self
.
num_encoder_layers
=
num_encoder_layers
self
.
num_unicoder_layers
=
num_unicoder_layers
self
.
num_decoder_layers
=
num_decoder_layers
self
.
num_queries
=
num_queries
assert
query_dim
==
4
# choose encoder layer type
encoder_layer
=
DeformableTransformerEncoderLayer
(
d_model
,
dim_feedforward
,
dropout
,
activation
,
num_feature_levels
,
nhead
,
enc_n_points
)
if
use_text_enhancer
:
text_enhance_layer
=
TransformerEncoderLayer
(
d_model
=
d_model
,
nhead
=
nhead
//
2
,
dim_feedforward
=
dim_feedforward
//
2
,
dropout
=
text_dropout
,
)
else
:
text_enhance_layer
=
None
if
use_fusion_layer
:
feature_fusion_layer
=
BiAttentionBlock
(
v_dim
=
d_model
,
l_dim
=
d_model
,
embed_dim
=
dim_feedforward
//
2
,
num_heads
=
nhead
//
2
,
dropout
=
fusion_dropout
,
drop_path
=
fusion_droppath
,
)
else
:
feature_fusion_layer
=
None
encoder_norm
=
nn
.
LayerNorm
(
d_model
)
if
normalize_before
else
None
assert
encoder_norm
is
None
self
.
encoder
=
TransformerEncoder
(
encoder_layer
,
num_encoder_layers
,
d_model
=
d_model
,
num_queries
=
num_queries
,
text_enhance_layer
=
text_enhance_layer
,
feature_fusion_layer
=
feature_fusion_layer
,
use_checkpoint
=
use_checkpoint
,
use_transformer_ckpt
=
use_transformer_ckpt
,
)
# choose decoder layer type
decoder_layer
=
DeformableTransformerDecoderLayer
(
d_model
,
dim_feedforward
,
dropout
,
activation
,
num_feature_levels
,
nhead
,
dec_n_points
,
use_text_cross_attention
=
use_text_cross_attention
,
)
decoder_norm
=
nn
.
LayerNorm
(
d_model
)
self
.
decoder
=
TransformerDecoder
(
decoder_layer
,
num_decoder_layers
,
decoder_norm
,
return_intermediate
=
return_intermediate_dec
,
d_model
=
d_model
,
query_dim
=
query_dim
,
num_feature_levels
=
num_feature_levels
,
)
self
.
d_model
=
d_model
self
.
nhead
=
nhead
self
.
dec_layers
=
num_decoder_layers
self
.
num_queries
=
num_queries
# useful for single stage model only
self
.
num_patterns
=
num_patterns
if
not
isinstance
(
num_patterns
,
int
):
Warning
(
"num_patterns should be int but {}"
.
format
(
type
(
num_patterns
)))
self
.
num_patterns
=
0
if
num_feature_levels
>
1
:
if
self
.
num_encoder_layers
>
0
:
self
.
level_embed
=
nn
.
Parameter
(
torch
.
Tensor
(
num_feature_levels
,
d_model
))
else
:
self
.
level_embed
=
None
self
.
learnable_tgt_init
=
learnable_tgt_init
assert
learnable_tgt_init
,
"why not learnable_tgt_init"
self
.
embed_init_tgt
=
embed_init_tgt
if
(
two_stage_type
!=
"no"
and
embed_init_tgt
)
or
(
two_stage_type
==
"no"
):
self
.
tgt_embed
=
nn
.
Embedding
(
self
.
num_queries
,
d_model
)
nn
.
init
.
normal_
(
self
.
tgt_embed
.
weight
.
data
)
else
:
self
.
tgt_embed
=
None
# for two stage
self
.
two_stage_type
=
two_stage_type
assert
two_stage_type
in
[
"no"
,
"standard"
],
"unknown param {} of two_stage_type"
.
format
(
two_stage_type
)
if
two_stage_type
==
"standard"
:
# anchor selection at the output of encoder
self
.
enc_output
=
nn
.
Linear
(
d_model
,
d_model
)
self
.
enc_output_norm
=
nn
.
LayerNorm
(
d_model
)
self
.
two_stage_wh_embedding
=
None
if
two_stage_type
==
"no"
:
self
.
init_ref_points
(
num_queries
)
# init self.refpoint_embed
self
.
enc_out_class_embed
=
None
self
.
enc_out_bbox_embed
=
None
self
.
_reset_parameters
()
def
_reset_parameters
(
self
):
for
p
in
self
.
parameters
():
if
p
.
dim
()
>
1
:
nn
.
init
.
xavier_uniform_
(
p
)
for
m
in
self
.
modules
():
if
isinstance
(
m
,
MSDeformAttn
):
m
.
_reset_parameters
()
if
self
.
num_feature_levels
>
1
and
self
.
level_embed
is
not
None
:
nn
.
init
.
normal_
(
self
.
level_embed
)
def
get_valid_ratio
(
self
,
mask
):
_
,
H
,
W
=
mask
.
shape
valid_H
=
torch
.
sum
(
~
mask
[:,
:,
0
],
1
)
valid_W
=
torch
.
sum
(
~
mask
[:,
0
,
:],
1
)
valid_ratio_h
=
valid_H
.
float
()
/
H
valid_ratio_w
=
valid_W
.
float
()
/
W
valid_ratio
=
torch
.
stack
([
valid_ratio_w
,
valid_ratio_h
],
-
1
)
return
valid_ratio
def
init_ref_points
(
self
,
use_num_queries
):
self
.
refpoint_embed
=
nn
.
Embedding
(
use_num_queries
,
4
)
def
forward
(
self
,
srcs
,
masks
,
refpoint_embed
,
pos_embeds
,
tgt
,
attn_mask
=
None
,
text_dict
=
None
):
"""
Input:
- srcs: List of multi features [bs, ci, hi, wi]
- masks: List of multi masks [bs, hi, wi]
- refpoint_embed: [bs, num_dn, 4]. None in infer
- pos_embeds: List of multi pos embeds [bs, ci, hi, wi]
- tgt: [bs, num_dn, d_model]. None in infer
"""
# prepare input for encoder
src_flatten
=
[]
mask_flatten
=
[]
lvl_pos_embed_flatten
=
[]
spatial_shapes
=
[]
for
lvl
,
(
src
,
mask
,
pos_embed
)
in
enumerate
(
zip
(
srcs
,
masks
,
pos_embeds
)):
bs
,
c
,
h
,
w
=
src
.
shape
spatial_shape
=
(
h
,
w
)
spatial_shapes
.
append
(
spatial_shape
)
src
=
src
.
flatten
(
2
).
transpose
(
1
,
2
)
# bs, hw, c
mask
=
mask
.
flatten
(
1
)
# bs, hw
pos_embed
=
pos_embed
.
flatten
(
2
).
transpose
(
1
,
2
)
# bs, hw, c
if
self
.
num_feature_levels
>
1
and
self
.
level_embed
is
not
None
:
lvl_pos_embed
=
pos_embed
+
self
.
level_embed
[
lvl
].
view
(
1
,
1
,
-
1
)
else
:
lvl_pos_embed
=
pos_embed
lvl_pos_embed_flatten
.
append
(
lvl_pos_embed
)
src_flatten
.
append
(
src
)
mask_flatten
.
append
(
mask
)
src_flatten
=
torch
.
cat
(
src_flatten
,
1
)
# bs, \sum{hxw}, c
mask_flatten
=
torch
.
cat
(
mask_flatten
,
1
)
# bs, \sum{hxw}
lvl_pos_embed_flatten
=
torch
.
cat
(
lvl_pos_embed_flatten
,
1
)
# bs, \sum{hxw}, c
spatial_shapes
=
torch
.
as_tensor
(
spatial_shapes
,
dtype
=
torch
.
long
,
device
=
src_flatten
.
device
)
level_start_index
=
torch
.
cat
(
(
spatial_shapes
.
new_zeros
((
1
,)),
spatial_shapes
.
prod
(
1
).
cumsum
(
0
)[:
-
1
])
)
valid_ratios
=
torch
.
stack
([
self
.
get_valid_ratio
(
m
)
for
m
in
masks
],
1
)
# two stage
enc_topk_proposals
=
enc_refpoint_embed
=
None
#########################################################
# Begin Encoder
#########################################################
memory
,
memory_text
=
self
.
encoder
(
src_flatten
,
pos
=
lvl_pos_embed_flatten
,
level_start_index
=
level_start_index
,
spatial_shapes
=
spatial_shapes
,
valid_ratios
=
valid_ratios
,
key_padding_mask
=
mask_flatten
,
memory_text
=
text_dict
[
"encoded_text"
],
text_attention_mask
=~
text_dict
[
"text_token_mask"
],
# we ~ the mask . False means use the token; True means pad the token
position_ids
=
text_dict
[
"position_ids"
],
text_self_attention_masks
=
text_dict
[
"text_self_attention_masks"
],
)
#########################################################
# End Encoder
# - memory: bs, \sum{hw}, c
# - mask_flatten: bs, \sum{hw}
# - lvl_pos_embed_flatten: bs, \sum{hw}, c
# - enc_intermediate_output: None or (nenc+1, bs, nq, c) or (nenc, bs, nq, c)
# - enc_intermediate_refpoints: None or (nenc+1, bs, nq, c) or (nenc, bs, nq, c)
#########################################################
text_dict
[
"encoded_text"
]
=
memory_text
# if os.environ.get("SHILONG_AMP_INFNAN_DEBUG") == '1':
# if memory.isnan().any() | memory.isinf().any():
# import ipdb; ipdb.set_trace()
if
self
.
two_stage_type
==
"standard"
:
output_memory
,
output_proposals
=
gen_encoder_output_proposals
(
memory
,
mask_flatten
,
spatial_shapes
)
output_memory
=
self
.
enc_output_norm
(
self
.
enc_output
(
output_memory
))
if
text_dict
is
not
None
:
enc_outputs_class_unselected
=
self
.
enc_out_class_embed
(
output_memory
,
text_dict
)
else
:
enc_outputs_class_unselected
=
self
.
enc_out_class_embed
(
output_memory
)
topk_logits
=
enc_outputs_class_unselected
.
max
(
-
1
)[
0
]
enc_outputs_coord_unselected
=
(
self
.
enc_out_bbox_embed
(
output_memory
)
+
output_proposals
)
# (bs, \sum{hw}, 4) unsigmoid
topk
=
self
.
num_queries
topk_proposals
=
torch
.
topk
(
topk_logits
,
topk
,
dim
=
1
)[
1
]
# bs, nq
# gather boxes
refpoint_embed_undetach
=
torch
.
gather
(
enc_outputs_coord_unselected
,
1
,
topk_proposals
.
unsqueeze
(
-
1
).
repeat
(
1
,
1
,
4
)
)
# unsigmoid
refpoint_embed_
=
refpoint_embed_undetach
.
detach
()
init_box_proposal
=
torch
.
gather
(
output_proposals
,
1
,
topk_proposals
.
unsqueeze
(
-
1
).
repeat
(
1
,
1
,
4
)
).
sigmoid
()
# sigmoid
# gather tgt
tgt_undetach
=
torch
.
gather
(
output_memory
,
1
,
topk_proposals
.
unsqueeze
(
-
1
).
repeat
(
1
,
1
,
self
.
d_model
)
)
if
self
.
embed_init_tgt
:
tgt_
=
(
self
.
tgt_embed
.
weight
[:,
None
,
:].
repeat
(
1
,
bs
,
1
).
transpose
(
0
,
1
)
)
# nq, bs, d_model
else
:
tgt_
=
tgt_undetach
.
detach
()
if
refpoint_embed
is
not
None
:
refpoint_embed
=
torch
.
cat
([
refpoint_embed
,
refpoint_embed_
],
dim
=
1
)
tgt
=
torch
.
cat
([
tgt
,
tgt_
],
dim
=
1
)
else
:
refpoint_embed
,
tgt
=
refpoint_embed_
,
tgt_
elif
self
.
two_stage_type
==
"no"
:
tgt_
=
(
self
.
tgt_embed
.
weight
[:,
None
,
:].
repeat
(
1
,
bs
,
1
).
transpose
(
0
,
1
)
)
# nq, bs, d_model
refpoint_embed_
=
(
self
.
refpoint_embed
.
weight
[:,
None
,
:].
repeat
(
1
,
bs
,
1
).
transpose
(
0
,
1
)
)
# nq, bs, 4
if
refpoint_embed
is
not
None
:
refpoint_embed
=
torch
.
cat
([
refpoint_embed
,
refpoint_embed_
],
dim
=
1
)
tgt
=
torch
.
cat
([
tgt
,
tgt_
],
dim
=
1
)
else
:
refpoint_embed
,
tgt
=
refpoint_embed_
,
tgt_
if
self
.
num_patterns
>
0
:
tgt_embed
=
tgt
.
repeat
(
1
,
self
.
num_patterns
,
1
)
refpoint_embed
=
refpoint_embed
.
repeat
(
1
,
self
.
num_patterns
,
1
)
tgt_pat
=
self
.
patterns
.
weight
[
None
,
:,
:].
repeat_interleave
(
self
.
num_queries
,
1
)
# 1, n_q*n_pat, d_model
tgt
=
tgt_embed
+
tgt_pat
init_box_proposal
=
refpoint_embed_
.
sigmoid
()
else
:
raise
NotImplementedError
(
"unknown two_stage_type {}"
.
format
(
self
.
two_stage_type
))
#########################################################
# End preparing tgt
# - tgt: bs, NQ, d_model
# - refpoint_embed(unsigmoid): bs, NQ, d_model
#########################################################
#########################################################
# Begin Decoder
#########################################################
hs
,
references
=
self
.
decoder
(
tgt
=
tgt
.
transpose
(
0
,
1
),
memory
=
memory
.
transpose
(
0
,
1
),
memory_key_padding_mask
=
mask_flatten
,
pos
=
lvl_pos_embed_flatten
.
transpose
(
0
,
1
),
refpoints_unsigmoid
=
refpoint_embed
.
transpose
(
0
,
1
),
level_start_index
=
level_start_index
,
spatial_shapes
=
spatial_shapes
,
valid_ratios
=
valid_ratios
,
tgt_mask
=
attn_mask
,
memory_text
=
text_dict
[
"encoded_text"
],
text_attention_mask
=~
text_dict
[
"text_token_mask"
],
# we ~ the mask . False means use the token; True means pad the token
)
#########################################################
# End Decoder
# hs: n_dec, bs, nq, d_model
# references: n_dec+1, bs, nq, query_dim
#########################################################
#########################################################
# Begin postprocess
#########################################################
if
self
.
two_stage_type
==
"standard"
:
hs_enc
=
tgt_undetach
.
unsqueeze
(
0
)
ref_enc
=
refpoint_embed_undetach
.
sigmoid
().
unsqueeze
(
0
)
else
:
hs_enc
=
ref_enc
=
None
#########################################################
# End postprocess
# hs_enc: (n_enc+1, bs, nq, d_model) or (1, bs, nq, d_model) or (n_enc, bs, nq, d_model) or None
# ref_enc: (n_enc+1, bs, nq, query_dim) or (1, bs, nq, query_dim) or (n_enc, bs, nq, d_model) or None
#########################################################
return
hs
,
references
,
hs_enc
,
ref_enc
,
init_box_proposal
# hs: (n_dec, bs, nq, d_model)
# references: sigmoid coordinates. (n_dec+1, bs, bq, 4)
# hs_enc: (n_enc+1, bs, nq, d_model) or (1, bs, nq, d_model) or None
# ref_enc: sigmoid coordinates. \
# (n_enc+1, bs, nq, query_dim) or (1, bs, nq, query_dim) or None
class
TransformerEncoder
(
nn
.
Module
):
def
__init__
(
self
,
encoder_layer
,
num_layers
,
d_model
=
256
,
num_queries
=
300
,
enc_layer_share
=
False
,
text_enhance_layer
=
None
,
feature_fusion_layer
=
None
,
use_checkpoint
=
False
,
use_transformer_ckpt
=
False
,
):
"""_summary_
Args:
encoder_layer (_type_): _description_
num_layers (_type_): _description_
norm (_type_, optional): _description_. Defaults to None.
d_model (int, optional): _description_. Defaults to 256.
num_queries (int, optional): _description_. Defaults to 300.
enc_layer_share (bool, optional): _description_. Defaults to False.
"""
super
().
__init__
()
# prepare layers
self
.
layers
=
[]
self
.
text_layers
=
[]
self
.
fusion_layers
=
[]
if
num_layers
>
0
:
self
.
layers
=
_get_clones
(
encoder_layer
,
num_layers
,
layer_share
=
enc_layer_share
)
if
text_enhance_layer
is
not
None
:
self
.
text_layers
=
_get_clones
(
text_enhance_layer
,
num_layers
,
layer_share
=
enc_layer_share
)
if
feature_fusion_layer
is
not
None
:
self
.
fusion_layers
=
_get_clones
(
feature_fusion_layer
,
num_layers
,
layer_share
=
enc_layer_share
)
else
:
self
.
layers
=
[]
del
encoder_layer
if
text_enhance_layer
is
not
None
:
self
.
text_layers
=
[]
del
text_enhance_layer
if
feature_fusion_layer
is
not
None
:
self
.
fusion_layers
=
[]
del
feature_fusion_layer
self
.
query_scale
=
None
self
.
num_queries
=
num_queries
self
.
num_layers
=
num_layers
self
.
d_model
=
d_model
self
.
use_checkpoint
=
use_checkpoint
self
.
use_transformer_ckpt
=
use_transformer_ckpt
@
staticmethod
def
get_reference_points
(
spatial_shapes
,
valid_ratios
,
device
):
reference_points_list
=
[]
for
lvl
,
(
H_
,
W_
)
in
enumerate
(
spatial_shapes
):
ref_y
,
ref_x
=
torch
.
meshgrid
(
torch
.
linspace
(
0.5
,
H_
-
0.5
,
H_
,
dtype
=
torch
.
float32
,
device
=
device
),
torch
.
linspace
(
0.5
,
W_
-
0.5
,
W_
,
dtype
=
torch
.
float32
,
device
=
device
),
)
ref_y
=
ref_y
.
reshape
(
-
1
)[
None
]
/
(
valid_ratios
[:,
None
,
lvl
,
1
]
*
H_
)
ref_x
=
ref_x
.
reshape
(
-
1
)[
None
]
/
(
valid_ratios
[:,
None
,
lvl
,
0
]
*
W_
)
ref
=
torch
.
stack
((
ref_x
,
ref_y
),
-
1
)
reference_points_list
.
append
(
ref
)
reference_points
=
torch
.
cat
(
reference_points_list
,
1
)
reference_points
=
reference_points
[:,
:,
None
]
*
valid_ratios
[:,
None
]
return
reference_points
def
forward
(
self
,
# for images
src
:
Tensor
,
pos
:
Tensor
,
spatial_shapes
:
Tensor
,
level_start_index
:
Tensor
,
valid_ratios
:
Tensor
,
key_padding_mask
:
Tensor
,
# for texts
memory_text
:
Tensor
=
None
,
text_attention_mask
:
Tensor
=
None
,
pos_text
:
Tensor
=
None
,
text_self_attention_masks
:
Tensor
=
None
,
position_ids
:
Tensor
=
None
,
):
"""
Input:
- src: [bs, sum(hi*wi), 256]
- pos: pos embed for src. [bs, sum(hi*wi), 256]
- spatial_shapes: h,w of each level [num_level, 2]
- level_start_index: [num_level] start point of level in sum(hi*wi).
- valid_ratios: [bs, num_level, 2]
- key_padding_mask: [bs, sum(hi*wi)]
- memory_text: bs, n_text, 256
- text_attention_mask: bs, n_text
False for no padding; True for padding
- pos_text: bs, n_text, 256
- position_ids: bs, n_text
Intermedia:
- reference_points: [bs, sum(hi*wi), num_level, 2]
Outpus:
- output: [bs, sum(hi*wi), 256]
"""
output
=
src
# preparation and reshape
if
self
.
num_layers
>
0
:
reference_points
=
self
.
get_reference_points
(
spatial_shapes
,
valid_ratios
,
device
=
src
.
device
)
if
self
.
text_layers
:
# generate pos_text
bs
,
n_text
,
text_dim
=
memory_text
.
shape
if
pos_text
is
None
and
position_ids
is
None
:
pos_text
=
(
torch
.
arange
(
n_text
,
device
=
memory_text
.
device
)
.
float
()
.
unsqueeze
(
0
)
.
unsqueeze
(
-
1
)
.
repeat
(
bs
,
1
,
1
)
)
pos_text
=
get_sine_pos_embed
(
pos_text
,
num_pos_feats
=
256
,
exchange_xy
=
False
)
if
position_ids
is
not
None
:
pos_text
=
get_sine_pos_embed
(
position_ids
[...,
None
],
num_pos_feats
=
256
,
exchange_xy
=
False
)
# main process
for
layer_id
,
layer
in
enumerate
(
self
.
layers
):
# if output.isnan().any() or memory_text.isnan().any():
# if os.environ.get('IPDB_SHILONG_DEBUG', None) == 'INFO':
# import ipdb; ipdb.set_trace()
if
self
.
fusion_layers
:
if
self
.
use_checkpoint
:
output
,
memory_text
=
checkpoint
.
checkpoint
(
self
.
fusion_layers
[
layer_id
],
output
,
memory_text
,
key_padding_mask
,
text_attention_mask
,
)
else
:
output
,
memory_text
=
self
.
fusion_layers
[
layer_id
](
v
=
output
,
l
=
memory_text
,
attention_mask_v
=
key_padding_mask
,
attention_mask_l
=
text_attention_mask
,
)
if
self
.
text_layers
:
memory_text
=
self
.
text_layers
[
layer_id
](
src
=
memory_text
.
transpose
(
0
,
1
),
src_mask
=~
text_self_attention_masks
,
# note we use ~ for mask here
src_key_padding_mask
=
text_attention_mask
,
pos
=
(
pos_text
.
transpose
(
0
,
1
)
if
pos_text
is
not
None
else
None
),
).
transpose
(
0
,
1
)
# main process
if
self
.
use_transformer_ckpt
:
output
=
checkpoint
.
checkpoint
(
layer
,
output
,
pos
,
reference_points
,
spatial_shapes
,
level_start_index
,
key_padding_mask
,
)
else
:
output
=
layer
(
src
=
output
,
pos
=
pos
,
reference_points
=
reference_points
,
spatial_shapes
=
spatial_shapes
,
level_start_index
=
level_start_index
,
key_padding_mask
=
key_padding_mask
,
)
return
output
,
memory_text
class
TransformerDecoder
(
nn
.
Module
):
def
__init__
(
self
,
decoder_layer
,
num_layers
,
norm
=
None
,
return_intermediate
=
False
,
d_model
=
256
,
query_dim
=
4
,
num_feature_levels
=
1
,
):
super
().
__init__
()
if
num_layers
>
0
:
self
.
layers
=
_get_clones
(
decoder_layer
,
num_layers
)
else
:
self
.
layers
=
[]
self
.
num_layers
=
num_layers
self
.
norm
=
norm
self
.
return_intermediate
=
return_intermediate
assert
return_intermediate
,
"support return_intermediate only"
self
.
query_dim
=
query_dim
assert
query_dim
in
[
2
,
4
],
"query_dim should be 2/4 but {}"
.
format
(
query_dim
)
self
.
num_feature_levels
=
num_feature_levels
self
.
ref_point_head
=
MLP
(
query_dim
//
2
*
d_model
,
d_model
,
d_model
,
2
)
self
.
query_pos_sine_scale
=
None
self
.
query_scale
=
None
self
.
bbox_embed
=
None
self
.
class_embed
=
None
self
.
d_model
=
d_model
self
.
ref_anchor_head
=
None
def
forward
(
self
,
tgt
,
memory
,
tgt_mask
:
Optional
[
Tensor
]
=
None
,
memory_mask
:
Optional
[
Tensor
]
=
None
,
tgt_key_padding_mask
:
Optional
[
Tensor
]
=
None
,
memory_key_padding_mask
:
Optional
[
Tensor
]
=
None
,
pos
:
Optional
[
Tensor
]
=
None
,
refpoints_unsigmoid
:
Optional
[
Tensor
]
=
None
,
# num_queries, bs, 2
# for memory
level_start_index
:
Optional
[
Tensor
]
=
None
,
# num_levels
spatial_shapes
:
Optional
[
Tensor
]
=
None
,
# bs, num_levels, 2
valid_ratios
:
Optional
[
Tensor
]
=
None
,
# for text
memory_text
:
Optional
[
Tensor
]
=
None
,
text_attention_mask
:
Optional
[
Tensor
]
=
None
,
):
"""
Input:
- tgt: nq, bs, d_model
- memory: hw, bs, d_model
- pos: hw, bs, d_model
- refpoints_unsigmoid: nq, bs, 2/4
- valid_ratios/spatial_shapes: bs, nlevel, 2
"""
output
=
tgt
intermediate
=
[]
reference_points
=
refpoints_unsigmoid
.
sigmoid
()
ref_points
=
[
reference_points
]
for
layer_id
,
layer
in
enumerate
(
self
.
layers
):
if
reference_points
.
shape
[
-
1
]
==
4
:
reference_points_input
=
(
reference_points
[:,
:,
None
]
*
torch
.
cat
([
valid_ratios
,
valid_ratios
],
-
1
)[
None
,
:]
)
# nq, bs, nlevel, 4
else
:
assert
reference_points
.
shape
[
-
1
]
==
2
reference_points_input
=
reference_points
[:,
:,
None
]
*
valid_ratios
[
None
,
:]
query_sine_embed
=
gen_sineembed_for_position
(
reference_points_input
[:,
:,
0
,
:]
)
# nq, bs, 256*2
# conditional query
raw_query_pos
=
self
.
ref_point_head
(
query_sine_embed
)
# nq, bs, 256
pos_scale
=
self
.
query_scale
(
output
)
if
self
.
query_scale
is
not
None
else
1
query_pos
=
pos_scale
*
raw_query_pos
# if os.environ.get("SHILONG_AMP_INFNAN_DEBUG") == '1':
# if query_pos.isnan().any() | query_pos.isinf().any():
# import ipdb; ipdb.set_trace()
# main process
output
=
layer
(
tgt
=
output
,
tgt_query_pos
=
query_pos
,
tgt_query_sine_embed
=
query_sine_embed
,
tgt_key_padding_mask
=
tgt_key_padding_mask
,
tgt_reference_points
=
reference_points_input
,
memory_text
=
memory_text
,
text_attention_mask
=
text_attention_mask
,
memory
=
memory
,
memory_key_padding_mask
=
memory_key_padding_mask
,
memory_level_start_index
=
level_start_index
,
memory_spatial_shapes
=
spatial_shapes
,
memory_pos
=
pos
,
self_attn_mask
=
tgt_mask
,
cross_attn_mask
=
memory_mask
,
)
if
output
.
isnan
().
any
()
|
output
.
isinf
().
any
():
print
(
f
"output layer_id
{
layer_id
}
is nan"
)
try
:
num_nan
=
output
.
isnan
().
sum
().
item
()
num_inf
=
output
.
isinf
().
sum
().
item
()
print
(
f
"num_nan
{
num_nan
}
, num_inf
{
num_inf
}
"
)
except
Exception
as
e
:
print
(
e
)
# if os.environ.get("SHILONG_AMP_INFNAN_DEBUG") == '1':
# import ipdb; ipdb.set_trace()
# iter update
if
self
.
bbox_embed
is
not
None
:
# box_holder = self.bbox_embed(output)
# box_holder[..., :self.query_dim] += inverse_sigmoid(reference_points)
# new_reference_points = box_holder[..., :self.query_dim].sigmoid()
reference_before_sigmoid
=
inverse_sigmoid
(
reference_points
)
delta_unsig
=
self
.
bbox_embed
[
layer_id
](
output
)
outputs_unsig
=
delta_unsig
+
reference_before_sigmoid
new_reference_points
=
outputs_unsig
.
sigmoid
()
reference_points
=
new_reference_points
.
detach
()
# if layer_id != self.num_layers - 1:
ref_points
.
append
(
new_reference_points
)
intermediate
.
append
(
self
.
norm
(
output
))
return
[
[
itm_out
.
transpose
(
0
,
1
)
for
itm_out
in
intermediate
],
[
itm_refpoint
.
transpose
(
0
,
1
)
for
itm_refpoint
in
ref_points
],
]
class
DeformableTransformerEncoderLayer
(
nn
.
Module
):
def
__init__
(
self
,
d_model
=
256
,
d_ffn
=
1024
,
dropout
=
0.1
,
activation
=
"relu"
,
n_levels
=
4
,
n_heads
=
8
,
n_points
=
4
,
):
super
().
__init__
()
# self attention
self
.
self_attn
=
MSDeformAttn
(
embed_dim
=
d_model
,
num_levels
=
n_levels
,
num_heads
=
n_heads
,
num_points
=
n_points
,
batch_first
=
True
,
)
self
.
dropout1
=
nn
.
Dropout
(
dropout
)
self
.
norm1
=
nn
.
LayerNorm
(
d_model
)
# ffn
self
.
linear1
=
nn
.
Linear
(
d_model
,
d_ffn
)
self
.
activation
=
_get_activation_fn
(
activation
,
d_model
=
d_ffn
)
self
.
dropout2
=
nn
.
Dropout
(
dropout
)
self
.
linear2
=
nn
.
Linear
(
d_ffn
,
d_model
)
self
.
dropout3
=
nn
.
Dropout
(
dropout
)
self
.
norm2
=
nn
.
LayerNorm
(
d_model
)
@
staticmethod
def
with_pos_embed
(
tensor
,
pos
):
return
tensor
if
pos
is
None
else
tensor
+
pos
def
forward_ffn
(
self
,
src
):
src2
=
self
.
linear2
(
self
.
dropout2
(
self
.
activation
(
self
.
linear1
(
src
))))
src
=
src
+
self
.
dropout3
(
src2
)
src
=
self
.
norm2
(
src
)
return
src
def
forward
(
self
,
src
,
pos
,
reference_points
,
spatial_shapes
,
level_start_index
,
key_padding_mask
=
None
):
# self attention
# import ipdb; ipdb.set_trace()
src2
=
self
.
self_attn
(
query
=
self
.
with_pos_embed
(
src
,
pos
),
reference_points
=
reference_points
,
value
=
src
,
spatial_shapes
=
spatial_shapes
,
level_start_index
=
level_start_index
,
key_padding_mask
=
key_padding_mask
,
)
src
=
src
+
self
.
dropout1
(
src2
)
src
=
self
.
norm1
(
src
)
# ffn
src
=
self
.
forward_ffn
(
src
)
return
src
class
DeformableTransformerDecoderLayer
(
nn
.
Module
):
def
__init__
(
self
,
d_model
=
256
,
d_ffn
=
1024
,
dropout
=
0.1
,
activation
=
"relu"
,
n_levels
=
4
,
n_heads
=
8
,
n_points
=
4
,
use_text_feat_guide
=
False
,
use_text_cross_attention
=
False
,
):
super
().
__init__
()
# cross attention
self
.
cross_attn
=
MSDeformAttn
(
embed_dim
=
d_model
,
num_levels
=
n_levels
,
num_heads
=
n_heads
,
num_points
=
n_points
,
batch_first
=
True
,
)
self
.
dropout1
=
nn
.
Dropout
(
dropout
)
if
dropout
>
0
else
nn
.
Identity
()
self
.
norm1
=
nn
.
LayerNorm
(
d_model
)
# cross attention text
if
use_text_cross_attention
:
self
.
ca_text
=
nn
.
MultiheadAttention
(
d_model
,
n_heads
,
dropout
=
dropout
)
self
.
catext_dropout
=
nn
.
Dropout
(
dropout
)
if
dropout
>
0
else
nn
.
Identity
()
self
.
catext_norm
=
nn
.
LayerNorm
(
d_model
)
# self attention
self
.
self_attn
=
nn
.
MultiheadAttention
(
d_model
,
n_heads
,
dropout
=
dropout
)
self
.
dropout2
=
nn
.
Dropout
(
dropout
)
if
dropout
>
0
else
nn
.
Identity
()
self
.
norm2
=
nn
.
LayerNorm
(
d_model
)
# ffn
self
.
linear1
=
nn
.
Linear
(
d_model
,
d_ffn
)
self
.
activation
=
_get_activation_fn
(
activation
,
d_model
=
d_ffn
,
batch_dim
=
1
)
self
.
dropout3
=
nn
.
Dropout
(
dropout
)
if
dropout
>
0
else
nn
.
Identity
()
self
.
linear2
=
nn
.
Linear
(
d_ffn
,
d_model
)
self
.
dropout4
=
nn
.
Dropout
(
dropout
)
if
dropout
>
0
else
nn
.
Identity
()
self
.
norm3
=
nn
.
LayerNorm
(
d_model
)
self
.
key_aware_proj
=
None
self
.
use_text_feat_guide
=
use_text_feat_guide
assert
not
use_text_feat_guide
self
.
use_text_cross_attention
=
use_text_cross_attention
def
rm_self_attn_modules
(
self
):
self
.
self_attn
=
None
self
.
dropout2
=
None
self
.
norm2
=
None
@
staticmethod
def
with_pos_embed
(
tensor
,
pos
):
return
tensor
if
pos
is
None
else
tensor
+
pos
def
forward_ffn
(
self
,
tgt
):
with
torch
.
cuda
.
amp
.
autocast
(
enabled
=
False
):
tgt2
=
self
.
linear2
(
self
.
dropout3
(
self
.
activation
(
self
.
linear1
(
tgt
))))
tgt
=
tgt
+
self
.
dropout4
(
tgt2
)
tgt
=
self
.
norm3
(
tgt
)
return
tgt
def
forward
(
self
,
# for tgt
tgt
:
Optional
[
Tensor
],
# nq, bs, d_model
tgt_query_pos
:
Optional
[
Tensor
]
=
None
,
# pos for query. MLP(Sine(pos))
tgt_query_sine_embed
:
Optional
[
Tensor
]
=
None
,
# pos for query. Sine(pos)
tgt_key_padding_mask
:
Optional
[
Tensor
]
=
None
,
tgt_reference_points
:
Optional
[
Tensor
]
=
None
,
# nq, bs, 4
memory_text
:
Optional
[
Tensor
]
=
None
,
# bs, num_token, d_model
text_attention_mask
:
Optional
[
Tensor
]
=
None
,
# bs, num_token
# for memory
memory
:
Optional
[
Tensor
]
=
None
,
# hw, bs, d_model
memory_key_padding_mask
:
Optional
[
Tensor
]
=
None
,
memory_level_start_index
:
Optional
[
Tensor
]
=
None
,
# num_levels
memory_spatial_shapes
:
Optional
[
Tensor
]
=
None
,
# bs, num_levels, 2
memory_pos
:
Optional
[
Tensor
]
=
None
,
# pos for memory
# sa
self_attn_mask
:
Optional
[
Tensor
]
=
None
,
# mask used for self-attention
cross_attn_mask
:
Optional
[
Tensor
]
=
None
,
# mask used for cross-attention
):
"""
Input:
- tgt/tgt_query_pos: nq, bs, d_model
-
"""
assert
cross_attn_mask
is
None
# self attention
if
self
.
self_attn
is
not
None
:
# import ipdb; ipdb.set_trace()
q
=
k
=
self
.
with_pos_embed
(
tgt
,
tgt_query_pos
)
tgt2
=
self
.
self_attn
(
q
,
k
,
tgt
,
attn_mask
=
self_attn_mask
)[
0
]
tgt
=
tgt
+
self
.
dropout2
(
tgt2
)
tgt
=
self
.
norm2
(
tgt
)
if
self
.
use_text_cross_attention
:
tgt2
=
self
.
ca_text
(
self
.
with_pos_embed
(
tgt
,
tgt_query_pos
),
memory_text
.
transpose
(
0
,
1
),
memory_text
.
transpose
(
0
,
1
),
key_padding_mask
=
text_attention_mask
,
)[
0
]
tgt
=
tgt
+
self
.
catext_dropout
(
tgt2
)
tgt
=
self
.
catext_norm
(
tgt
)
tgt2
=
self
.
cross_attn
(
query
=
self
.
with_pos_embed
(
tgt
,
tgt_query_pos
).
transpose
(
0
,
1
),
reference_points
=
tgt_reference_points
.
transpose
(
0
,
1
).
contiguous
(),
value
=
memory
.
transpose
(
0
,
1
),
spatial_shapes
=
memory_spatial_shapes
,
level_start_index
=
memory_level_start_index
,
key_padding_mask
=
memory_key_padding_mask
,
).
transpose
(
0
,
1
)
tgt
=
tgt
+
self
.
dropout1
(
tgt2
)
tgt
=
self
.
norm1
(
tgt
)
# ffn
tgt
=
self
.
forward_ffn
(
tgt
)
return
tgt
def
build_transformer
(
args
):
return
Transformer
(
d_model
=
args
.
hidden_dim
,
dropout
=
args
.
dropout
,
nhead
=
args
.
nheads
,
num_queries
=
args
.
num_queries
,
dim_feedforward
=
args
.
dim_feedforward
,
num_encoder_layers
=
args
.
enc_layers
,
num_decoder_layers
=
args
.
dec_layers
,
normalize_before
=
args
.
pre_norm
,
return_intermediate_dec
=
True
,
query_dim
=
args
.
query_dim
,
activation
=
args
.
transformer_activation
,
num_patterns
=
args
.
num_patterns
,
num_feature_levels
=
args
.
num_feature_levels
,
enc_n_points
=
args
.
enc_n_points
,
dec_n_points
=
args
.
dec_n_points
,
learnable_tgt_init
=
True
,
# two stage
two_stage_type
=
args
.
two_stage_type
,
# ['no', 'standard', 'early']
embed_init_tgt
=
args
.
embed_init_tgt
,
use_text_enhancer
=
args
.
use_text_enhancer
,
use_fusion_layer
=
args
.
use_fusion_layer
,
use_checkpoint
=
args
.
use_checkpoint
,
use_transformer_ckpt
=
args
.
use_transformer_ckpt
,
use_text_cross_attention
=
args
.
use_text_cross_attention
,
text_dropout
=
args
.
text_dropout
,
fusion_dropout
=
args
.
fusion_dropout
,
fusion_droppath
=
args
.
fusion_droppath
,
)
groundingdino/models/GroundingDINO/transformer_vanilla.py
0 → 100644
View file @
34e4011b
# ------------------------------------------------------------------------
# Grounding DINO
# url: https://github.com/IDEA-Research/GroundingDINO
# Copyright (c) 2023 IDEA. All Rights Reserved.
# Licensed under the Apache License, Version 2.0 [see LICENSE for details]
# ------------------------------------------------------------------------
# Copyright (c) Aishwarya Kamath & Nicolas Carion. Licensed under the Apache License 2.0. All Rights Reserved
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
"""
DETR Transformer class.
Copy-paste from torch.nn.Transformer with modifications:
* positional encodings are passed in MHattention
* extra LN at the end of encoder is removed
* decoder returns a stack of activations from all decoding layers
"""
from
typing
import
Optional
import
torch
import
torch.nn.functional
as
F
from
torch
import
Tensor
,
nn
from
.utils
import
(
MLP
,
_get_activation_fn
,
_get_clones
,
gen_encoder_output_proposals
,
gen_sineembed_for_position
,
sigmoid_focal_loss
,
)
class
TextTransformer
(
nn
.
Module
):
def
__init__
(
self
,
num_layers
,
d_model
=
256
,
nheads
=
8
,
dim_feedforward
=
2048
,
dropout
=
0.1
):
super
().
__init__
()
self
.
num_layers
=
num_layers
self
.
d_model
=
d_model
self
.
nheads
=
nheads
self
.
dim_feedforward
=
dim_feedforward
self
.
norm
=
None
single_encoder_layer
=
TransformerEncoderLayer
(
d_model
=
d_model
,
nhead
=
nheads
,
dim_feedforward
=
dim_feedforward
,
dropout
=
dropout
)
self
.
layers
=
_get_clones
(
single_encoder_layer
,
num_layers
)
def
forward
(
self
,
memory_text
:
torch
.
Tensor
,
text_attention_mask
:
torch
.
Tensor
):
"""
Args:
text_attention_mask: bs, num_token
memory_text: bs, num_token, d_model
Raises:
RuntimeError: _description_
Returns:
output: bs, num_token, d_model
"""
output
=
memory_text
.
transpose
(
0
,
1
)
for
layer
in
self
.
layers
:
output
=
layer
(
output
,
src_key_padding_mask
=
text_attention_mask
)
if
self
.
norm
is
not
None
:
output
=
self
.
norm
(
output
)
return
output
.
transpose
(
0
,
1
)
class
TransformerEncoderLayer
(
nn
.
Module
):
def
__init__
(
self
,
d_model
,
nhead
,
dim_feedforward
=
2048
,
dropout
=
0.1
,
activation
=
"relu"
,
normalize_before
=
False
,
):
super
().
__init__
()
self
.
self_attn
=
nn
.
MultiheadAttention
(
d_model
,
nhead
,
dropout
=
dropout
)
# Implementation of Feedforward model
self
.
linear1
=
nn
.
Linear
(
d_model
,
dim_feedforward
)
self
.
dropout
=
nn
.
Dropout
(
dropout
)
self
.
linear2
=
nn
.
Linear
(
dim_feedforward
,
d_model
)
self
.
norm1
=
nn
.
LayerNorm
(
d_model
)
self
.
norm2
=
nn
.
LayerNorm
(
d_model
)
self
.
dropout1
=
nn
.
Dropout
(
dropout
)
self
.
dropout2
=
nn
.
Dropout
(
dropout
)
self
.
activation
=
_get_activation_fn
(
activation
)
self
.
normalize_before
=
normalize_before
self
.
nhead
=
nhead
def
with_pos_embed
(
self
,
tensor
,
pos
:
Optional
[
Tensor
]):
return
tensor
if
pos
is
None
else
tensor
+
pos
def
forward
(
self
,
src
,
src_mask
:
Optional
[
Tensor
]
=
None
,
src_key_padding_mask
:
Optional
[
Tensor
]
=
None
,
pos
:
Optional
[
Tensor
]
=
None
,
):
# repeat attn mask
if
src_mask
.
dim
()
==
3
and
src_mask
.
shape
[
0
]
==
src
.
shape
[
1
]:
# bs, num_q, num_k
src_mask
=
src_mask
.
repeat
(
self
.
nhead
,
1
,
1
)
q
=
k
=
self
.
with_pos_embed
(
src
,
pos
)
src2
=
self
.
self_attn
(
q
,
k
,
value
=
src
,
attn_mask
=
src_mask
)[
0
]
# src2 = self.self_attn(q, k, value=src, attn_mask=src_mask, key_padding_mask=src_key_padding_mask)[0]
src
=
src
+
self
.
dropout1
(
src2
)
src
=
self
.
norm1
(
src
)
src2
=
self
.
linear2
(
self
.
dropout
(
self
.
activation
(
self
.
linear1
(
src
))))
src
=
src
+
self
.
dropout2
(
src2
)
src
=
self
.
norm2
(
src
)
return
src
groundingdino/models/GroundingDINO/utils.py
0 → 100644
View file @
34e4011b
# ------------------------------------------------------------------------
# Grounding DINO
# url: https://github.com/IDEA-Research/GroundingDINO
# Copyright (c) 2023 IDEA. All Rights Reserved.
# Licensed under the Apache License, Version 2.0 [see LICENSE for details]
# ------------------------------------------------------------------------
import
copy
import
math
import
torch
import
torch.nn.functional
as
F
from
torch
import
Tensor
,
nn
def
_get_clones
(
module
,
N
,
layer_share
=
False
):
# import ipdb; ipdb.set_trace()
if
layer_share
:
return
nn
.
ModuleList
([
module
for
i
in
range
(
N
)])
else
:
return
nn
.
ModuleList
([
copy
.
deepcopy
(
module
)
for
i
in
range
(
N
)])
def
get_sine_pos_embed
(
pos_tensor
:
torch
.
Tensor
,
num_pos_feats
:
int
=
128
,
temperature
:
int
=
10000
,
exchange_xy
:
bool
=
True
,
):
"""generate sine position embedding from a position tensor
Args:
pos_tensor (torch.Tensor): shape: [..., n].
num_pos_feats (int): projected shape for each float in the tensor.
temperature (int): temperature in the sine/cosine function.
exchange_xy (bool, optional): exchange pos x and pos y.
\
For example, input tensor is [x,y], the results will be [pos(y), pos(x)]. Defaults to True.
Returns:
pos_embed (torch.Tensor): shape: [..., n*num_pos_feats].
"""
scale
=
2
*
math
.
pi
dim_t
=
torch
.
arange
(
num_pos_feats
,
dtype
=
torch
.
float32
,
device
=
pos_tensor
.
device
)
dim_t
=
temperature
**
(
2
*
torch
.
div
(
dim_t
,
2
,
rounding_mode
=
"floor"
)
/
num_pos_feats
)
def
sine_func
(
x
:
torch
.
Tensor
):
sin_x
=
x
*
scale
/
dim_t
sin_x
=
torch
.
stack
((
sin_x
[...,
0
::
2
].
sin
(),
sin_x
[...,
1
::
2
].
cos
()),
dim
=
3
).
flatten
(
2
)
return
sin_x
pos_res
=
[
sine_func
(
x
)
for
x
in
pos_tensor
.
split
([
1
]
*
pos_tensor
.
shape
[
-
1
],
dim
=-
1
)]
if
exchange_xy
:
pos_res
[
0
],
pos_res
[
1
]
=
pos_res
[
1
],
pos_res
[
0
]
pos_res
=
torch
.
cat
(
pos_res
,
dim
=-
1
)
return
pos_res
def
gen_encoder_output_proposals
(
memory
:
Tensor
,
memory_padding_mask
:
Tensor
,
spatial_shapes
:
Tensor
,
learnedwh
=
None
):
"""
Input:
- memory: bs, \sum{hw}, d_model
- memory_padding_mask: bs, \sum{hw}
- spatial_shapes: nlevel, 2
- learnedwh: 2
Output:
- output_memory: bs, \sum{hw}, d_model
- output_proposals: bs, \sum{hw}, 4
"""
N_
,
S_
,
C_
=
memory
.
shape
proposals
=
[]
_cur
=
0
for
lvl
,
(
H_
,
W_
)
in
enumerate
(
spatial_shapes
):
mask_flatten_
=
memory_padding_mask
[:,
_cur
:
(
_cur
+
H_
*
W_
)].
view
(
N_
,
H_
,
W_
,
1
)
valid_H
=
torch
.
sum
(
~
mask_flatten_
[:,
:,
0
,
0
],
1
)
valid_W
=
torch
.
sum
(
~
mask_flatten_
[:,
0
,
:,
0
],
1
)
# import ipdb; ipdb.set_trace()
grid_y
,
grid_x
=
torch
.
meshgrid
(
torch
.
linspace
(
0
,
H_
-
1
,
H_
,
dtype
=
torch
.
float32
,
device
=
memory
.
device
),
torch
.
linspace
(
0
,
W_
-
1
,
W_
,
dtype
=
torch
.
float32
,
device
=
memory
.
device
),
)
grid
=
torch
.
cat
([
grid_x
.
unsqueeze
(
-
1
),
grid_y
.
unsqueeze
(
-
1
)],
-
1
)
# H_, W_, 2
scale
=
torch
.
cat
([
valid_W
.
unsqueeze
(
-
1
),
valid_H
.
unsqueeze
(
-
1
)],
1
).
view
(
N_
,
1
,
1
,
2
)
grid
=
(
grid
.
unsqueeze
(
0
).
expand
(
N_
,
-
1
,
-
1
,
-
1
)
+
0.5
)
/
scale
if
learnedwh
is
not
None
:
# import ipdb; ipdb.set_trace()
wh
=
torch
.
ones_like
(
grid
)
*
learnedwh
.
sigmoid
()
*
(
2.0
**
lvl
)
else
:
wh
=
torch
.
ones_like
(
grid
)
*
0.05
*
(
2.0
**
lvl
)
# scale = torch.cat([W_[None].unsqueeze(-1), H_[None].unsqueeze(-1)], 1).view(1, 1, 1, 2).repeat(N_, 1, 1, 1)
# grid = (grid.unsqueeze(0).expand(N_, -1, -1, -1) + 0.5) / scale
# wh = torch.ones_like(grid) / scale
proposal
=
torch
.
cat
((
grid
,
wh
),
-
1
).
view
(
N_
,
-
1
,
4
)
proposals
.
append
(
proposal
)
_cur
+=
H_
*
W_
# import ipdb; ipdb.set_trace()
output_proposals
=
torch
.
cat
(
proposals
,
1
)
output_proposals_valid
=
((
output_proposals
>
0.01
)
&
(
output_proposals
<
0.99
)).
all
(
-
1
,
keepdim
=
True
)
output_proposals
=
torch
.
log
(
output_proposals
/
(
1
-
output_proposals
))
# unsigmoid
output_proposals
=
output_proposals
.
masked_fill
(
memory_padding_mask
.
unsqueeze
(
-
1
),
float
(
"inf"
))
output_proposals
=
output_proposals
.
masked_fill
(
~
output_proposals_valid
,
float
(
"inf"
))
output_memory
=
memory
output_memory
=
output_memory
.
masked_fill
(
memory_padding_mask
.
unsqueeze
(
-
1
),
float
(
0
))
output_memory
=
output_memory
.
masked_fill
(
~
output_proposals_valid
,
float
(
0
))
# output_memory = output_memory.masked_fill(memory_padding_mask.unsqueeze(-1), float('inf'))
# output_memory = output_memory.masked_fill(~output_proposals_valid, float('inf'))
return
output_memory
,
output_proposals
class
RandomBoxPerturber
:
def
__init__
(
self
,
x_noise_scale
=
0.2
,
y_noise_scale
=
0.2
,
w_noise_scale
=
0.2
,
h_noise_scale
=
0.2
)
->
None
:
self
.
noise_scale
=
torch
.
Tensor
(
[
x_noise_scale
,
y_noise_scale
,
w_noise_scale
,
h_noise_scale
]
)
def
__call__
(
self
,
refanchors
:
Tensor
)
->
Tensor
:
nq
,
bs
,
query_dim
=
refanchors
.
shape
device
=
refanchors
.
device
noise_raw
=
torch
.
rand_like
(
refanchors
)
noise_scale
=
self
.
noise_scale
.
to
(
device
)[:
query_dim
]
new_refanchors
=
refanchors
*
(
1
+
(
noise_raw
-
0.5
)
*
noise_scale
)
return
new_refanchors
.
clamp_
(
0
,
1
)
def
sigmoid_focal_loss
(
inputs
,
targets
,
num_boxes
,
alpha
:
float
=
0.25
,
gamma
:
float
=
2
,
no_reduction
=
False
):
"""
Loss used in RetinaNet for dense detection: https://arxiv.org/abs/1708.02002.
Args:
inputs: A float tensor of arbitrary shape.
The predictions for each example.
targets: A float tensor with the same shape as inputs. Stores the binary
classification label for each element in inputs
(0 for the negative class and 1 for the positive class).
alpha: (optional) Weighting factor in range (0,1) to balance
positive vs negative examples. Default = -1 (no weighting).
gamma: Exponent of the modulating factor (1 - p_t) to
balance easy vs hard examples.
Returns:
Loss tensor
"""
prob
=
inputs
.
sigmoid
()
ce_loss
=
F
.
binary_cross_entropy_with_logits
(
inputs
,
targets
,
reduction
=
"none"
)
p_t
=
prob
*
targets
+
(
1
-
prob
)
*
(
1
-
targets
)
loss
=
ce_loss
*
((
1
-
p_t
)
**
gamma
)
if
alpha
>=
0
:
alpha_t
=
alpha
*
targets
+
(
1
-
alpha
)
*
(
1
-
targets
)
loss
=
alpha_t
*
loss
if
no_reduction
:
return
loss
return
loss
.
mean
(
1
).
sum
()
/
num_boxes
class
MLP
(
nn
.
Module
):
"""Very simple multi-layer perceptron (also called FFN)"""
def
__init__
(
self
,
input_dim
,
hidden_dim
,
output_dim
,
num_layers
):
super
().
__init__
()
self
.
num_layers
=
num_layers
h
=
[
hidden_dim
]
*
(
num_layers
-
1
)
self
.
layers
=
nn
.
ModuleList
(
nn
.
Linear
(
n
,
k
)
for
n
,
k
in
zip
([
input_dim
]
+
h
,
h
+
[
output_dim
])
)
def
forward
(
self
,
x
):
for
i
,
layer
in
enumerate
(
self
.
layers
):
x
=
F
.
relu
(
layer
(
x
))
if
i
<
self
.
num_layers
-
1
else
layer
(
x
)
return
x
def
_get_activation_fn
(
activation
,
d_model
=
256
,
batch_dim
=
0
):
"""Return an activation function given a string"""
if
activation
==
"relu"
:
return
F
.
relu
if
activation
==
"gelu"
:
return
F
.
gelu
if
activation
==
"glu"
:
return
F
.
glu
if
activation
==
"prelu"
:
return
nn
.
PReLU
()
if
activation
==
"selu"
:
return
F
.
selu
raise
RuntimeError
(
f
"activation should be relu/gelu, not
{
activation
}
."
)
def
gen_sineembed_for_position
(
pos_tensor
):
# n_query, bs, _ = pos_tensor.size()
# sineembed_tensor = torch.zeros(n_query, bs, 256)
scale
=
2
*
math
.
pi
dim_t
=
torch
.
arange
(
128
,
dtype
=
torch
.
float32
,
device
=
pos_tensor
.
device
)
dim_t
=
10000
**
(
2
*
(
torch
.
div
(
dim_t
,
2
,
rounding_mode
=
'floor'
))
/
128
)
x_embed
=
pos_tensor
[:,
:,
0
]
*
scale
y_embed
=
pos_tensor
[:,
:,
1
]
*
scale
pos_x
=
x_embed
[:,
:,
None
]
/
dim_t
pos_y
=
y_embed
[:,
:,
None
]
/
dim_t
pos_x
=
torch
.
stack
((
pos_x
[:,
:,
0
::
2
].
sin
(),
pos_x
[:,
:,
1
::
2
].
cos
()),
dim
=
3
).
flatten
(
2
)
pos_y
=
torch
.
stack
((
pos_y
[:,
:,
0
::
2
].
sin
(),
pos_y
[:,
:,
1
::
2
].
cos
()),
dim
=
3
).
flatten
(
2
)
if
pos_tensor
.
size
(
-
1
)
==
2
:
pos
=
torch
.
cat
((
pos_y
,
pos_x
),
dim
=
2
)
elif
pos_tensor
.
size
(
-
1
)
==
4
:
w_embed
=
pos_tensor
[:,
:,
2
]
*
scale
pos_w
=
w_embed
[:,
:,
None
]
/
dim_t
pos_w
=
torch
.
stack
((
pos_w
[:,
:,
0
::
2
].
sin
(),
pos_w
[:,
:,
1
::
2
].
cos
()),
dim
=
3
).
flatten
(
2
)
h_embed
=
pos_tensor
[:,
:,
3
]
*
scale
pos_h
=
h_embed
[:,
:,
None
]
/
dim_t
pos_h
=
torch
.
stack
((
pos_h
[:,
:,
0
::
2
].
sin
(),
pos_h
[:,
:,
1
::
2
].
cos
()),
dim
=
3
).
flatten
(
2
)
pos
=
torch
.
cat
((
pos_y
,
pos_x
,
pos_w
,
pos_h
),
dim
=
2
)
else
:
raise
ValueError
(
"Unknown pos_tensor shape(-1):{}"
.
format
(
pos_tensor
.
size
(
-
1
)))
return
pos
class
ContrastiveEmbed
(
nn
.
Module
):
def
__init__
(
self
,
max_text_len
=
256
):
"""
Args:
max_text_len: max length of text.
"""
super
().
__init__
()
self
.
max_text_len
=
max_text_len
def
forward
(
self
,
x
,
text_dict
):
"""_summary_
Args:
x (_type_): _description_
text_dict (_type_): _description_
{
'encoded_text': encoded_text, # bs, 195, d_model
'text_token_mask': text_token_mask, # bs, 195
# True for used tokens. False for padding tokens
}
Returns:
_type_: _description_
"""
assert
isinstance
(
text_dict
,
dict
)
y
=
text_dict
[
"encoded_text"
]
text_token_mask
=
text_dict
[
"text_token_mask"
]
res
=
x
@
y
.
transpose
(
-
1
,
-
2
)
res
.
masked_fill_
(
~
text_token_mask
[:,
None
,
:],
float
(
"-inf"
))
# padding to max_text_len
new_res
=
torch
.
full
((
*
res
.
shape
[:
-
1
],
self
.
max_text_len
),
float
(
"-inf"
),
device
=
res
.
device
)
new_res
[...,
:
res
.
shape
[
-
1
]]
=
res
return
new_res
groundingdino/models/__init__.py
0 → 100644
View file @
34e4011b
# ------------------------------------------------------------------------
# Grounding DINO
# url: https://github.com/IDEA-Research/GroundingDINO
# Copyright (c) 2023 IDEA. All Rights Reserved.
# Licensed under the Apache License, Version 2.0 [see LICENSE for details]
# ------------------------------------------------------------------------
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
from
.GroundingDINO
import
build_groundingdino
def
build_model
(
args
):
# we use register to maintain models from catdet6 on.
from
.registry
import
MODULE_BUILD_FUNCS
assert
args
.
modelname
in
MODULE_BUILD_FUNCS
.
_module_dict
build_func
=
MODULE_BUILD_FUNCS
.
get
(
args
.
modelname
)
model
=
build_func
(
args
)
return
model
groundingdino/models/registry.py
0 → 100644
View file @
34e4011b
# ------------------------------------------------------------------------
# Grounding DINO
# url: https://github.com/IDEA-Research/GroundingDINO
# Copyright (c) 2023 IDEA. All Rights Reserved.
# Licensed under the Apache License, Version 2.0 [see LICENSE for details]
# ------------------------------------------------------------------------
# -*- coding: utf-8 -*-
# @Author: Yihao Chen
# @Date: 2021-08-16 16:03:17
# @Last Modified by: Shilong Liu
# @Last Modified time: 2022-01-23 15:26
# modified from mmcv
import
inspect
from
functools
import
partial
class
Registry
(
object
):
def
__init__
(
self
,
name
):
self
.
_name
=
name
self
.
_module_dict
=
dict
()
def
__repr__
(
self
):
format_str
=
self
.
__class__
.
__name__
+
"(name={}, items={})"
.
format
(
self
.
_name
,
list
(
self
.
_module_dict
.
keys
())
)
return
format_str
def
__len__
(
self
):
return
len
(
self
.
_module_dict
)
@
property
def
name
(
self
):
return
self
.
_name
@
property
def
module_dict
(
self
):
return
self
.
_module_dict
def
get
(
self
,
key
):
return
self
.
_module_dict
.
get
(
key
,
None
)
def
registe_with_name
(
self
,
module_name
=
None
,
force
=
False
):
return
partial
(
self
.
register
,
module_name
=
module_name
,
force
=
force
)
def
register
(
self
,
module_build_function
,
module_name
=
None
,
force
=
False
):
"""Register a module build function.
Args:
module (:obj:`nn.Module`): Module to be registered.
"""
if
not
inspect
.
isfunction
(
module_build_function
):
raise
TypeError
(
"module_build_function must be a function, but got {}"
.
format
(
type
(
module_build_function
)
)
)
if
module_name
is
None
:
module_name
=
module_build_function
.
__name__
if
not
force
and
module_name
in
self
.
_module_dict
:
raise
KeyError
(
"{} is already registered in {}"
.
format
(
module_name
,
self
.
name
))
self
.
_module_dict
[
module_name
]
=
module_build_function
return
module_build_function
MODULE_BUILD_FUNCS
=
Registry
(
"model build functions"
)
groundingdino/util/__init__.py
0 → 100644
View file @
34e4011b
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
Prev
1
2
3
4
5
6
7
8
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment