Unverified Commit 83e96dc0 authored by Sangbum Daniel Choi's avatar Sangbum Daniel Choi Committed by GitHub
Browse files

Add cuda_custom_kernel in DETA (#28989)

* enable graident checkpointing in DetaObjectDetection

* fix missing part in original DETA

* make style

* make fix-copies

* Revert "make fix-copies"

This reverts commit 4041c86c29248f1673e8173b677c20b5a4511358.

* remove fix-copies of DetaDecoder

* enable swin gradient checkpointing

* fix gradient checkpointing in donut_swin

* add tests for deta/swin/donut

* Revert "fix gradient checkpointing in donut_swin"

This reverts commit 1cf345e34d3cc0e09eb800d9895805b1dd9b474d.

* change supports_gradient_checkpointing pipeline to PreTrainedModel

* Revert "add tests for deta/swin/donut"

This reverts commit 6056ffbb1eddc3cb3a99e4ebb231ae3edf295f5b.

* Revert "Revert "fix gradient checkpointing in donut_swin""

This reverts commit 24e25d0a14891241de58a0d86f817d0b5d2a341f.

* Simple revert

* enable deformable detr gradient checkpointing

* add gradient in encoder

* add cuda_custom_kernel function in MSDA

* make style and fix input of DetaMSDA

* make fix-copies

* remove n_levels in input of DetaMSDA

* minor changes

* refactor custom_cuda_kernel like yoso format
https://github.com/huggingface/transformers/blob/0507e69d34f8902422eb4977ec066dd6bef179a0/src/transformers/models/yoso/modeling_yoso.py#L53
parent f3788b09
/*!
**************************************************************************************************
* Deformable DETR
* Copyright (c) 2020 SenseTime. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 [see LICENSE for details]
**************************************************************************************************
* Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
**************************************************************************************************
*/
#include <vector>
#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h>
at::Tensor
ms_deform_attn_cpu_forward(
const at::Tensor &value,
const at::Tensor &spatial_shapes,
const at::Tensor &level_start_index,
const at::Tensor &sampling_loc,
const at::Tensor &attn_weight,
const int im2col_step)
{
AT_ERROR("Not implement on cpu");
}
std::vector<at::Tensor>
ms_deform_attn_cpu_backward(
const at::Tensor &value,
const at::Tensor &spatial_shapes,
const at::Tensor &level_start_index,
const at::Tensor &sampling_loc,
const at::Tensor &attn_weight,
const at::Tensor &grad_output,
const int im2col_step)
{
AT_ERROR("Not implement on cpu");
}
/*!
**************************************************************************************************
* Deformable DETR
* Copyright (c) 2020 SenseTime. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 [see LICENSE for details]
**************************************************************************************************
* Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
**************************************************************************************************
*/
#pragma once
#include <torch/extension.h>
at::Tensor
ms_deform_attn_cpu_forward(
const at::Tensor &value,
const at::Tensor &spatial_shapes,
const at::Tensor &level_start_index,
const at::Tensor &sampling_loc,
const at::Tensor &attn_weight,
const int im2col_step);
std::vector<at::Tensor>
ms_deform_attn_cpu_backward(
const at::Tensor &value,
const at::Tensor &spatial_shapes,
const at::Tensor &level_start_index,
const at::Tensor &sampling_loc,
const at::Tensor &attn_weight,
const at::Tensor &grad_output,
const int im2col_step);
/*!
**************************************************************************************************
* Deformable DETR
* Copyright (c) 2020 SenseTime. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 [see LICENSE for details]
**************************************************************************************************
* Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
**************************************************************************************************
*/
#include <vector>
#include "cuda/ms_deform_im2col_cuda.cuh"
#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h>
#include <cuda.h>
#include <cuda_runtime.h>
#pragma once
#include <torch/extension.h>
at::Tensor ms_deform_attn_cuda_forward(
const at::Tensor &value,
const at::Tensor &spatial_shapes,
const at::Tensor &level_start_index,
const at::Tensor &sampling_loc,
const at::Tensor &attn_weight,
const int im2col_step)
{
AT_ASSERTM(value.is_contiguous(), "value tensor has to be contiguous");
AT_ASSERTM(spatial_shapes.is_contiguous(), "spatial_shapes tensor has to be contiguous");
AT_ASSERTM(level_start_index.is_contiguous(), "level_start_index tensor has to be contiguous");
AT_ASSERTM(sampling_loc.is_contiguous(), "sampling_loc tensor has to be contiguous");
AT_ASSERTM(attn_weight.is_contiguous(), "attn_weight tensor has to be contiguous");
AT_ASSERTM(value.type().is_cuda(), "value must be a CUDA tensor");
AT_ASSERTM(spatial_shapes.type().is_cuda(), "spatial_shapes must be a CUDA tensor");
AT_ASSERTM(level_start_index.type().is_cuda(), "level_start_index must be a CUDA tensor");
AT_ASSERTM(sampling_loc.type().is_cuda(), "sampling_loc must be a CUDA tensor");
AT_ASSERTM(attn_weight.type().is_cuda(), "attn_weight must be a CUDA tensor");
const int batch = value.size(0);
const int spatial_size = value.size(1);
const int num_heads = value.size(2);
const int channels = value.size(3);
const int num_levels = spatial_shapes.size(0);
const int num_query = sampling_loc.size(1);
const int num_point = sampling_loc.size(4);
const int im2col_step_ = std::min(batch, im2col_step);
AT_ASSERTM(batch % im2col_step_ == 0, "batch(%d) must divide im2col_step(%d)", batch, im2col_step_);
auto output = at::zeros({batch, num_query, num_heads, channels}, value.options());
const int batch_n = im2col_step_;
auto output_n = output.view({batch/im2col_step_, batch_n, num_query, num_heads, channels});
auto per_value_size = spatial_size * num_heads * channels;
auto per_sample_loc_size = num_query * num_heads * num_levels * num_point * 2;
auto per_attn_weight_size = num_query * num_heads * num_levels * num_point;
for (int n = 0; n < batch/im2col_step_; ++n)
{
auto columns = output_n.select(0, n);
AT_DISPATCH_FLOATING_TYPES(value.type(), "ms_deform_attn_forward_cuda", ([&] {
ms_deformable_im2col_cuda(at::cuda::getCurrentCUDAStream(),
value.data<scalar_t>() + n * im2col_step_ * per_value_size,
spatial_shapes.data<int64_t>(),
level_start_index.data<int64_t>(),
sampling_loc.data<scalar_t>() + n * im2col_step_ * per_sample_loc_size,
attn_weight.data<scalar_t>() + n * im2col_step_ * per_attn_weight_size,
batch_n, spatial_size, num_heads, channels, num_levels, num_query, num_point,
columns.data<scalar_t>());
}));
}
output = output.view({batch, num_query, num_heads*channels});
return output;
}
std::vector<at::Tensor> ms_deform_attn_cuda_backward(
const at::Tensor &value,
const at::Tensor &spatial_shapes,
const at::Tensor &level_start_index,
const at::Tensor &sampling_loc,
const at::Tensor &attn_weight,
const at::Tensor &grad_output,
const int im2col_step)
{
AT_ASSERTM(value.is_contiguous(), "value tensor has to be contiguous");
AT_ASSERTM(spatial_shapes.is_contiguous(), "spatial_shapes tensor has to be contiguous");
AT_ASSERTM(level_start_index.is_contiguous(), "level_start_index tensor has to be contiguous");
AT_ASSERTM(sampling_loc.is_contiguous(), "sampling_loc tensor has to be contiguous");
AT_ASSERTM(attn_weight.is_contiguous(), "attn_weight tensor has to be contiguous");
AT_ASSERTM(grad_output.is_contiguous(), "grad_output tensor has to be contiguous");
AT_ASSERTM(value.type().is_cuda(), "value must be a CUDA tensor");
AT_ASSERTM(spatial_shapes.type().is_cuda(), "spatial_shapes must be a CUDA tensor");
AT_ASSERTM(level_start_index.type().is_cuda(), "level_start_index must be a CUDA tensor");
AT_ASSERTM(sampling_loc.type().is_cuda(), "sampling_loc must be a CUDA tensor");
AT_ASSERTM(attn_weight.type().is_cuda(), "attn_weight must be a CUDA tensor");
AT_ASSERTM(grad_output.type().is_cuda(), "grad_output must be a CUDA tensor");
const int batch = value.size(0);
const int spatial_size = value.size(1);
const int num_heads = value.size(2);
const int channels = value.size(3);
const int num_levels = spatial_shapes.size(0);
const int num_query = sampling_loc.size(1);
const int num_point = sampling_loc.size(4);
const int im2col_step_ = std::min(batch, im2col_step);
AT_ASSERTM(batch % im2col_step_ == 0, "batch(%d) must divide im2col_step(%d)", batch, im2col_step_);
auto grad_value = at::zeros_like(value);
auto grad_sampling_loc = at::zeros_like(sampling_loc);
auto grad_attn_weight = at::zeros_like(attn_weight);
const int batch_n = im2col_step_;
auto per_value_size = spatial_size * num_heads * channels;
auto per_sample_loc_size = num_query * num_heads * num_levels * num_point * 2;
auto per_attn_weight_size = num_query * num_heads * num_levels * num_point;
auto grad_output_n = grad_output.view({batch/im2col_step_, batch_n, num_query, num_heads, channels});
for (int n = 0; n < batch/im2col_step_; ++n)
{
auto grad_output_g = grad_output_n.select(0, n);
AT_DISPATCH_FLOATING_TYPES(value.type(), "ms_deform_attn_backward_cuda", ([&] {
ms_deformable_col2im_cuda(at::cuda::getCurrentCUDAStream(),
grad_output_g.data<scalar_t>(),
value.data<scalar_t>() + n * im2col_step_ * per_value_size,
spatial_shapes.data<int64_t>(),
level_start_index.data<int64_t>(),
sampling_loc.data<scalar_t>() + n * im2col_step_ * per_sample_loc_size,
attn_weight.data<scalar_t>() + n * im2col_step_ * per_attn_weight_size,
batch_n, spatial_size, num_heads, channels, num_levels, num_query, num_point,
grad_value.data<scalar_t>() + n * im2col_step_ * per_value_size,
grad_sampling_loc.data<scalar_t>() + n * im2col_step_ * per_sample_loc_size,
grad_attn_weight.data<scalar_t>() + n * im2col_step_ * per_attn_weight_size);
}));
}
return {
grad_value, grad_sampling_loc, grad_attn_weight
};
}
This diff is collapsed.
/*!
**************************************************************************************************
* Deformable DETR
* Copyright (c) 2020 SenseTime. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 [see LICENSE for details]
**************************************************************************************************
* Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
**************************************************************************************************
*/
#pragma once
#include <torch/extension.h>
at::Tensor ms_deform_attn_cuda_forward(
const at::Tensor &value,
const at::Tensor &spatial_shapes,
const at::Tensor &level_start_index,
const at::Tensor &sampling_loc,
const at::Tensor &attn_weight,
const int im2col_step);
std::vector<at::Tensor> ms_deform_attn_cuda_backward(
const at::Tensor &value,
const at::Tensor &spatial_shapes,
const at::Tensor &level_start_index,
const at::Tensor &sampling_loc,
const at::Tensor &attn_weight,
const at::Tensor &grad_output,
const int im2col_step);
This diff is collapsed.
/*!
**************************************************************************************************
* Deformable DETR
* Copyright (c) 2020 SenseTime. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 [see LICENSE for details]
**************************************************************************************************
* Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
**************************************************************************************************
*/
#pragma once
#include "cpu/ms_deform_attn_cpu.h"
#ifdef WITH_CUDA
#include "cuda/ms_deform_attn_cuda.h"
#endif
at::Tensor
ms_deform_attn_forward(
const at::Tensor &value,
const at::Tensor &spatial_shapes,
const at::Tensor &level_start_index,
const at::Tensor &sampling_loc,
const at::Tensor &attn_weight,
const int im2col_step)
{
if (value.type().is_cuda())
{
#ifdef WITH_CUDA
return ms_deform_attn_cuda_forward(
value, spatial_shapes, level_start_index, sampling_loc, attn_weight, im2col_step);
#else
AT_ERROR("Not compiled with GPU support");
#endif
}
AT_ERROR("Not implemented on the CPU");
}
std::vector<at::Tensor>
ms_deform_attn_backward(
const at::Tensor &value,
const at::Tensor &spatial_shapes,
const at::Tensor &level_start_index,
const at::Tensor &sampling_loc,
const at::Tensor &attn_weight,
const at::Tensor &grad_output,
const int im2col_step)
{
if (value.type().is_cuda())
{
#ifdef WITH_CUDA
return ms_deform_attn_cuda_backward(
value, spatial_shapes, level_start_index, sampling_loc, attn_weight, grad_output, im2col_step);
#else
AT_ERROR("Not compiled with GPU support");
#endif
}
AT_ERROR("Not implemented on the CPU");
}
/*!
**************************************************************************************************
* Deformable DETR
* Copyright (c) 2020 SenseTime. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 [see LICENSE for details]
**************************************************************************************************
* Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
**************************************************************************************************
*/
#include "ms_deform_attn.h"
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("ms_deform_attn_forward", &ms_deform_attn_forward, "ms_deform_attn_forward");
m.def("ms_deform_attn_backward", &ms_deform_attn_backward, "ms_deform_attn_backward");
}
\ No newline at end of file
...@@ -125,6 +125,9 @@ class DetaConfig(PretrainedConfig): ...@@ -125,6 +125,9 @@ class DetaConfig(PretrainedConfig):
Whether to assign each prediction i to the highest overlapping ground truth object if the overlap is larger than a threshold 0.7. Whether to assign each prediction i to the highest overlapping ground truth object if the overlap is larger than a threshold 0.7.
assign_second_stage (`bool`, *optional*, defaults to `True`): assign_second_stage (`bool`, *optional*, defaults to `True`):
Whether to assign second assignment procedure in the second stage closely follows the first stage assignment procedure. Whether to assign second assignment procedure in the second stage closely follows the first stage assignment procedure.
disable_custom_kernels (`bool`, *optional*, defaults to `True`):
Disable the use of custom CUDA and CPU kernels. This option is necessary for the ONNX export, as custom
kernels are not supported by PyTorch ONNX export.
Examples: Examples:
...@@ -191,6 +194,7 @@ class DetaConfig(PretrainedConfig): ...@@ -191,6 +194,7 @@ class DetaConfig(PretrainedConfig):
giou_loss_coefficient=2, giou_loss_coefficient=2,
eos_coefficient=0.1, eos_coefficient=0.1,
focal_alpha=0.25, focal_alpha=0.25,
disable_custom_kernels=True,
**kwargs, **kwargs,
): ):
if use_pretrained_backbone: if use_pretrained_backbone:
...@@ -256,6 +260,7 @@ class DetaConfig(PretrainedConfig): ...@@ -256,6 +260,7 @@ class DetaConfig(PretrainedConfig):
self.giou_loss_coefficient = giou_loss_coefficient self.giou_loss_coefficient = giou_loss_coefficient
self.eos_coefficient = eos_coefficient self.eos_coefficient = eos_coefficient
self.focal_alpha = focal_alpha self.focal_alpha = focal_alpha
self.disable_custom_kernels = disable_custom_kernels
super().__init__(is_encoder_decoder=is_encoder_decoder, **kwargs) super().__init__(is_encoder_decoder=is_encoder_decoder, **kwargs)
@property @property
......
...@@ -17,13 +17,17 @@ ...@@ -17,13 +17,17 @@
import copy import copy
import math import math
import os
import warnings import warnings
from dataclasses import dataclass from dataclasses import dataclass
from pathlib import Path
from typing import Dict, List, Optional, Tuple, Union from typing import Dict, List, Optional, Tuple, Union
import torch import torch
import torch.nn.functional as F import torch.nn.functional as F
from torch import Tensor, nn from torch import Tensor, nn
from torch.autograd import Function
from torch.autograd.function import once_differentiable
from ...activations import ACT2FN from ...activations import ACT2FN
from ...file_utils import ( from ...file_utils import (
...@@ -31,6 +35,7 @@ from ...file_utils import ( ...@@ -31,6 +35,7 @@ from ...file_utils import (
add_start_docstrings, add_start_docstrings,
add_start_docstrings_to_model_forward, add_start_docstrings_to_model_forward,
is_scipy_available, is_scipy_available,
is_torch_cuda_available,
is_vision_available, is_vision_available,
replace_return_docstrings, replace_return_docstrings,
) )
...@@ -38,7 +43,7 @@ from ...modeling_attn_mask_utils import _prepare_4d_attention_mask ...@@ -38,7 +43,7 @@ from ...modeling_attn_mask_utils import _prepare_4d_attention_mask
from ...modeling_outputs import BaseModelOutput from ...modeling_outputs import BaseModelOutput
from ...modeling_utils import PreTrainedModel from ...modeling_utils import PreTrainedModel
from ...pytorch_utils import meshgrid from ...pytorch_utils import meshgrid
from ...utils import is_accelerate_available, is_torchvision_available, logging, requires_backends from ...utils import is_accelerate_available, is_ninja_available, is_torchvision_available, logging, requires_backends
from ...utils.backbone_utils import load_backbone from ...utils.backbone_utils import load_backbone
from .configuration_deta import DetaConfig from .configuration_deta import DetaConfig
...@@ -46,6 +51,99 @@ from .configuration_deta import DetaConfig ...@@ -46,6 +51,99 @@ from .configuration_deta import DetaConfig
logger = logging.get_logger(__name__) logger = logging.get_logger(__name__)
def load_cuda_kernels():
from torch.utils.cpp_extension import load
root = Path(__file__).resolve().parent.parent.parent / "kernels" / "deta"
src_files = [
root / filename
for filename in [
"vision.cpp",
os.path.join("cpu", "ms_deform_attn_cpu.cpp"),
os.path.join("cuda", "ms_deform_attn_cuda.cu"),
]
]
load(
"MultiScaleDeformableAttention",
src_files,
with_cuda=True,
extra_include_paths=[str(root)],
extra_cflags=["-DWITH_CUDA=1"],
extra_cuda_cflags=[
"-DCUDA_HAS_FP16=1",
"-D__CUDA_NO_HALF_OPERATORS__",
"-D__CUDA_NO_HALF_CONVERSIONS__",
"-D__CUDA_NO_HALF2_OPERATORS__",
],
)
import MultiScaleDeformableAttention as MSDA
return MSDA
# Move this to not compile only when importing, this needs to happen later, like in __init__.
if is_torch_cuda_available() and is_ninja_available():
logger.info("Loading custom CUDA kernels...")
try:
MultiScaleDeformableAttention = load_cuda_kernels()
except Exception as e:
logger.warning(f"Could not load the custom kernel for multi-scale deformable attention: {e}")
MultiScaleDeformableAttention = None
else:
MultiScaleDeformableAttention = None
# Copied from transformers.models.deformable_detr.modeling_deformable_detr.MultiScaleDeformableAttentionFunction
class MultiScaleDeformableAttentionFunction(Function):
@staticmethod
def forward(
context,
value,
value_spatial_shapes,
value_level_start_index,
sampling_locations,
attention_weights,
im2col_step,
):
context.im2col_step = im2col_step
output = MultiScaleDeformableAttention.ms_deform_attn_forward(
value,
value_spatial_shapes,
value_level_start_index,
sampling_locations,
attention_weights,
context.im2col_step,
)
context.save_for_backward(
value, value_spatial_shapes, value_level_start_index, sampling_locations, attention_weights
)
return output
@staticmethod
@once_differentiable
def backward(context, grad_output):
(
value,
value_spatial_shapes,
value_level_start_index,
sampling_locations,
attention_weights,
) = context.saved_tensors
grad_value, grad_sampling_loc, grad_attn_weight = MultiScaleDeformableAttention.ms_deform_attn_backward(
value,
value_spatial_shapes,
value_level_start_index,
sampling_locations,
attention_weights,
grad_output,
context.im2col_step,
)
return grad_value, None, None, grad_sampling_loc, grad_attn_weight, None
if is_accelerate_available(): if is_accelerate_available():
from accelerate import PartialState from accelerate import PartialState
from accelerate.utils import reduce from accelerate.utils import reduce
...@@ -490,18 +588,19 @@ def multi_scale_deformable_attention( ...@@ -490,18 +588,19 @@ def multi_scale_deformable_attention(
return output.transpose(1, 2).contiguous() return output.transpose(1, 2).contiguous()
# Copied from transformers.models.deformable_detr.modeling_deformable_detr.DeformableDetrMultiscaleDeformableAttention with DeformableDetr->Deta
class DetaMultiscaleDeformableAttention(nn.Module): class DetaMultiscaleDeformableAttention(nn.Module):
""" """
Multiscale deformable attention as proposed in Deformable DETR. Multiscale deformable attention as proposed in Deformable DETR.
""" """
def __init__(self, embed_dim: int, num_heads: int, n_levels: int, n_points: int): def __init__(self, config: DetaConfig, num_heads: int, n_points: int):
super().__init__() super().__init__()
if embed_dim % num_heads != 0: if config.d_model % num_heads != 0:
raise ValueError( raise ValueError(
f"embed_dim (d_model) must be divisible by num_heads, but got {embed_dim} and {num_heads}" f"embed_dim (d_model) must be divisible by num_heads, but got {config.d_model} and {num_heads}"
) )
dim_per_head = embed_dim // num_heads dim_per_head = config.d_model // num_heads
# check if dim_per_head is power of 2 # check if dim_per_head is power of 2
if not ((dim_per_head & (dim_per_head - 1) == 0) and dim_per_head != 0): if not ((dim_per_head & (dim_per_head - 1) == 0) and dim_per_head != 0):
warnings.warn( warnings.warn(
...@@ -512,15 +611,17 @@ class DetaMultiscaleDeformableAttention(nn.Module): ...@@ -512,15 +611,17 @@ class DetaMultiscaleDeformableAttention(nn.Module):
self.im2col_step = 64 self.im2col_step = 64
self.d_model = embed_dim self.d_model = config.d_model
self.n_levels = n_levels self.n_levels = config.num_feature_levels
self.n_heads = num_heads self.n_heads = num_heads
self.n_points = n_points self.n_points = n_points
self.sampling_offsets = nn.Linear(embed_dim, num_heads * n_levels * n_points * 2) self.sampling_offsets = nn.Linear(config.d_model, num_heads * self.n_levels * n_points * 2)
self.attention_weights = nn.Linear(embed_dim, num_heads * n_levels * n_points) self.attention_weights = nn.Linear(config.d_model, num_heads * self.n_levels * n_points)
self.value_proj = nn.Linear(embed_dim, embed_dim) self.value_proj = nn.Linear(config.d_model, config.d_model)
self.output_proj = nn.Linear(embed_dim, embed_dim) self.output_proj = nn.Linear(config.d_model, config.d_model)
self.disable_custom_kernels = config.disable_custom_kernels
self._reset_parameters() self._reset_parameters()
...@@ -598,7 +699,23 @@ class DetaMultiscaleDeformableAttention(nn.Module): ...@@ -598,7 +699,23 @@ class DetaMultiscaleDeformableAttention(nn.Module):
) )
else: else:
raise ValueError(f"Last dim of reference_points must be 2 or 4, but got {reference_points.shape[-1]}") raise ValueError(f"Last dim of reference_points must be 2 or 4, but got {reference_points.shape[-1]}")
# PyTorch implementation (for now)
if self.disable_custom_kernels:
# PyTorch implementation
output = multi_scale_deformable_attention(value, spatial_shapes, sampling_locations, attention_weights)
else:
try:
# custom kernel
output = MultiScaleDeformableAttentionFunction.apply(
value,
spatial_shapes,
level_start_index,
sampling_locations,
attention_weights,
self.im2col_step,
)
except Exception:
# PyTorch implementation
output = multi_scale_deformable_attention(value, spatial_shapes, sampling_locations, attention_weights) output = multi_scale_deformable_attention(value, spatial_shapes, sampling_locations, attention_weights)
output = self.output_proj(output) output = self.output_proj(output)
...@@ -728,9 +845,8 @@ class DetaEncoderLayer(nn.Module): ...@@ -728,9 +845,8 @@ class DetaEncoderLayer(nn.Module):
super().__init__() super().__init__()
self.embed_dim = config.d_model self.embed_dim = config.d_model
self.self_attn = DetaMultiscaleDeformableAttention( self.self_attn = DetaMultiscaleDeformableAttention(
embed_dim=self.embed_dim, config,
num_heads=config.encoder_attention_heads, num_heads=config.encoder_attention_heads,
n_levels=config.num_feature_levels,
n_points=config.encoder_n_points, n_points=config.encoder_n_points,
) )
self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim) self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim)
...@@ -829,9 +945,8 @@ class DetaDecoderLayer(nn.Module): ...@@ -829,9 +945,8 @@ class DetaDecoderLayer(nn.Module):
self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim) self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim)
# cross-attention # cross-attention
self.encoder_attn = DetaMultiscaleDeformableAttention( self.encoder_attn = DetaMultiscaleDeformableAttention(
embed_dim=self.embed_dim, config,
num_heads=config.decoder_attention_heads, num_heads=config.decoder_attention_heads,
n_levels=config.num_feature_levels,
n_points=config.decoder_n_points, n_points=config.decoder_n_points,
) )
self.encoder_attn_layer_norm = nn.LayerNorm(self.embed_dim) self.encoder_attn_layer_norm = nn.LayerNorm(self.embed_dim)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment