Unverified Commit 83e96dc0 authored by Sangbum Daniel Choi's avatar Sangbum Daniel Choi Committed by GitHub
Browse files

Add cuda_custom_kernel in DETA (#28989)

* enable graident checkpointing in DetaObjectDetection

* fix missing part in original DETA

* make style

* make fix-copies

* Revert "make fix-copies"

This reverts commit 4041c86c29248f1673e8173b677c20b5a4511358.

* remove fix-copies of DetaDecoder

* enable swin gradient checkpointing

* fix gradient checkpointing in donut_swin

* add tests for deta/swin/donut

* Revert "fix gradient checkpointing in donut_swin"

This reverts commit 1cf345e34d3cc0e09eb800d9895805b1dd9b474d.

* change supports_gradient_checkpointing pipeline to PreTrainedModel

* Revert "add tests for deta/swin/donut"

This reverts commit 6056ffbb1eddc3cb3a99e4ebb231ae3edf295f5b.

* Revert "Revert "fix gradient checkpointing in donut_swin""

This reverts commit 24e25d0a14891241de58a0d86f817d0b5d2a341f.

* Simple revert

* enable deformable detr gradient checkpointing

* add gradient in encoder

* add cuda_custom_kernel function in MSDA

* make style and fix input of DetaMSDA

* make fix-copies

* remove n_levels in input of DetaMSDA

* minor changes

* refactor custom_cuda_kernel like yoso format
https://github.com/huggingface/transformers/blob/0507e69d34f8902422eb4977ec066dd6bef179a0/src/transformers/models/yoso/modeling_yoso.py#L53
parent f3788b09
/*!
**************************************************************************************************
* Deformable DETR
* Copyright (c) 2020 SenseTime. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 [see LICENSE for details]
**************************************************************************************************
* Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
**************************************************************************************************
*/
#include <vector>
#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h>
at::Tensor
ms_deform_attn_cpu_forward(
const at::Tensor &value,
const at::Tensor &spatial_shapes,
const at::Tensor &level_start_index,
const at::Tensor &sampling_loc,
const at::Tensor &attn_weight,
const int im2col_step)
{
AT_ERROR("Not implement on cpu");
}
std::vector<at::Tensor>
ms_deform_attn_cpu_backward(
const at::Tensor &value,
const at::Tensor &spatial_shapes,
const at::Tensor &level_start_index,
const at::Tensor &sampling_loc,
const at::Tensor &attn_weight,
const at::Tensor &grad_output,
const int im2col_step)
{
AT_ERROR("Not implement on cpu");
}
/*!
**************************************************************************************************
* Deformable DETR
* Copyright (c) 2020 SenseTime. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 [see LICENSE for details]
**************************************************************************************************
* Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
**************************************************************************************************
*/
#pragma once
#include <torch/extension.h>
at::Tensor
ms_deform_attn_cpu_forward(
const at::Tensor &value,
const at::Tensor &spatial_shapes,
const at::Tensor &level_start_index,
const at::Tensor &sampling_loc,
const at::Tensor &attn_weight,
const int im2col_step);
std::vector<at::Tensor>
ms_deform_attn_cpu_backward(
const at::Tensor &value,
const at::Tensor &spatial_shapes,
const at::Tensor &level_start_index,
const at::Tensor &sampling_loc,
const at::Tensor &attn_weight,
const at::Tensor &grad_output,
const int im2col_step);
/*!
**************************************************************************************************
* Deformable DETR
* Copyright (c) 2020 SenseTime. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 [see LICENSE for details]
**************************************************************************************************
* Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
**************************************************************************************************
*/
#include <vector>
#include "cuda/ms_deform_im2col_cuda.cuh"
#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h>
#include <cuda.h>
#include <cuda_runtime.h>
#pragma once
#include <torch/extension.h>
at::Tensor ms_deform_attn_cuda_forward(
const at::Tensor &value,
const at::Tensor &spatial_shapes,
const at::Tensor &level_start_index,
const at::Tensor &sampling_loc,
const at::Tensor &attn_weight,
const int im2col_step)
{
AT_ASSERTM(value.is_contiguous(), "value tensor has to be contiguous");
AT_ASSERTM(spatial_shapes.is_contiguous(), "spatial_shapes tensor has to be contiguous");
AT_ASSERTM(level_start_index.is_contiguous(), "level_start_index tensor has to be contiguous");
AT_ASSERTM(sampling_loc.is_contiguous(), "sampling_loc tensor has to be contiguous");
AT_ASSERTM(attn_weight.is_contiguous(), "attn_weight tensor has to be contiguous");
AT_ASSERTM(value.type().is_cuda(), "value must be a CUDA tensor");
AT_ASSERTM(spatial_shapes.type().is_cuda(), "spatial_shapes must be a CUDA tensor");
AT_ASSERTM(level_start_index.type().is_cuda(), "level_start_index must be a CUDA tensor");
AT_ASSERTM(sampling_loc.type().is_cuda(), "sampling_loc must be a CUDA tensor");
AT_ASSERTM(attn_weight.type().is_cuda(), "attn_weight must be a CUDA tensor");
const int batch = value.size(0);
const int spatial_size = value.size(1);
const int num_heads = value.size(2);
const int channels = value.size(3);
const int num_levels = spatial_shapes.size(0);
const int num_query = sampling_loc.size(1);
const int num_point = sampling_loc.size(4);
const int im2col_step_ = std::min(batch, im2col_step);
AT_ASSERTM(batch % im2col_step_ == 0, "batch(%d) must divide im2col_step(%d)", batch, im2col_step_);
auto output = at::zeros({batch, num_query, num_heads, channels}, value.options());
const int batch_n = im2col_step_;
auto output_n = output.view({batch/im2col_step_, batch_n, num_query, num_heads, channels});
auto per_value_size = spatial_size * num_heads * channels;
auto per_sample_loc_size = num_query * num_heads * num_levels * num_point * 2;
auto per_attn_weight_size = num_query * num_heads * num_levels * num_point;
for (int n = 0; n < batch/im2col_step_; ++n)
{
auto columns = output_n.select(0, n);
AT_DISPATCH_FLOATING_TYPES(value.type(), "ms_deform_attn_forward_cuda", ([&] {
ms_deformable_im2col_cuda(at::cuda::getCurrentCUDAStream(),
value.data<scalar_t>() + n * im2col_step_ * per_value_size,
spatial_shapes.data<int64_t>(),
level_start_index.data<int64_t>(),
sampling_loc.data<scalar_t>() + n * im2col_step_ * per_sample_loc_size,
attn_weight.data<scalar_t>() + n * im2col_step_ * per_attn_weight_size,
batch_n, spatial_size, num_heads, channels, num_levels, num_query, num_point,
columns.data<scalar_t>());
}));
}
output = output.view({batch, num_query, num_heads*channels});
return output;
}
std::vector<at::Tensor> ms_deform_attn_cuda_backward(
const at::Tensor &value,
const at::Tensor &spatial_shapes,
const at::Tensor &level_start_index,
const at::Tensor &sampling_loc,
const at::Tensor &attn_weight,
const at::Tensor &grad_output,
const int im2col_step)
{
AT_ASSERTM(value.is_contiguous(), "value tensor has to be contiguous");
AT_ASSERTM(spatial_shapes.is_contiguous(), "spatial_shapes tensor has to be contiguous");
AT_ASSERTM(level_start_index.is_contiguous(), "level_start_index tensor has to be contiguous");
AT_ASSERTM(sampling_loc.is_contiguous(), "sampling_loc tensor has to be contiguous");
AT_ASSERTM(attn_weight.is_contiguous(), "attn_weight tensor has to be contiguous");
AT_ASSERTM(grad_output.is_contiguous(), "grad_output tensor has to be contiguous");
AT_ASSERTM(value.type().is_cuda(), "value must be a CUDA tensor");
AT_ASSERTM(spatial_shapes.type().is_cuda(), "spatial_shapes must be a CUDA tensor");
AT_ASSERTM(level_start_index.type().is_cuda(), "level_start_index must be a CUDA tensor");
AT_ASSERTM(sampling_loc.type().is_cuda(), "sampling_loc must be a CUDA tensor");
AT_ASSERTM(attn_weight.type().is_cuda(), "attn_weight must be a CUDA tensor");
AT_ASSERTM(grad_output.type().is_cuda(), "grad_output must be a CUDA tensor");
const int batch = value.size(0);
const int spatial_size = value.size(1);
const int num_heads = value.size(2);
const int channels = value.size(3);
const int num_levels = spatial_shapes.size(0);
const int num_query = sampling_loc.size(1);
const int num_point = sampling_loc.size(4);
const int im2col_step_ = std::min(batch, im2col_step);
AT_ASSERTM(batch % im2col_step_ == 0, "batch(%d) must divide im2col_step(%d)", batch, im2col_step_);
auto grad_value = at::zeros_like(value);
auto grad_sampling_loc = at::zeros_like(sampling_loc);
auto grad_attn_weight = at::zeros_like(attn_weight);
const int batch_n = im2col_step_;
auto per_value_size = spatial_size * num_heads * channels;
auto per_sample_loc_size = num_query * num_heads * num_levels * num_point * 2;
auto per_attn_weight_size = num_query * num_heads * num_levels * num_point;
auto grad_output_n = grad_output.view({batch/im2col_step_, batch_n, num_query, num_heads, channels});
for (int n = 0; n < batch/im2col_step_; ++n)
{
auto grad_output_g = grad_output_n.select(0, n);
AT_DISPATCH_FLOATING_TYPES(value.type(), "ms_deform_attn_backward_cuda", ([&] {
ms_deformable_col2im_cuda(at::cuda::getCurrentCUDAStream(),
grad_output_g.data<scalar_t>(),
value.data<scalar_t>() + n * im2col_step_ * per_value_size,
spatial_shapes.data<int64_t>(),
level_start_index.data<int64_t>(),
sampling_loc.data<scalar_t>() + n * im2col_step_ * per_sample_loc_size,
attn_weight.data<scalar_t>() + n * im2col_step_ * per_attn_weight_size,
batch_n, spatial_size, num_heads, channels, num_levels, num_query, num_point,
grad_value.data<scalar_t>() + n * im2col_step_ * per_value_size,
grad_sampling_loc.data<scalar_t>() + n * im2col_step_ * per_sample_loc_size,
grad_attn_weight.data<scalar_t>() + n * im2col_step_ * per_attn_weight_size);
}));
}
return {
grad_value, grad_sampling_loc, grad_attn_weight
};
}
This diff is collapsed.
/*!
**************************************************************************************************
* Deformable DETR
* Copyright (c) 2020 SenseTime. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 [see LICENSE for details]
**************************************************************************************************
* Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
**************************************************************************************************
*/
#pragma once
#include <torch/extension.h>
at::Tensor ms_deform_attn_cuda_forward(
const at::Tensor &value,
const at::Tensor &spatial_shapes,
const at::Tensor &level_start_index,
const at::Tensor &sampling_loc,
const at::Tensor &attn_weight,
const int im2col_step);
std::vector<at::Tensor> ms_deform_attn_cuda_backward(
const at::Tensor &value,
const at::Tensor &spatial_shapes,
const at::Tensor &level_start_index,
const at::Tensor &sampling_loc,
const at::Tensor &attn_weight,
const at::Tensor &grad_output,
const int im2col_step);
This diff is collapsed.
/*!
**************************************************************************************************
* Deformable DETR
* Copyright (c) 2020 SenseTime. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 [see LICENSE for details]
**************************************************************************************************
* Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
**************************************************************************************************
*/
#pragma once
#include "cpu/ms_deform_attn_cpu.h"
#ifdef WITH_CUDA
#include "cuda/ms_deform_attn_cuda.h"
#endif
at::Tensor
ms_deform_attn_forward(
const at::Tensor &value,
const at::Tensor &spatial_shapes,
const at::Tensor &level_start_index,
const at::Tensor &sampling_loc,
const at::Tensor &attn_weight,
const int im2col_step)
{
if (value.type().is_cuda())
{
#ifdef WITH_CUDA
return ms_deform_attn_cuda_forward(
value, spatial_shapes, level_start_index, sampling_loc, attn_weight, im2col_step);
#else
AT_ERROR("Not compiled with GPU support");
#endif
}
AT_ERROR("Not implemented on the CPU");
}
std::vector<at::Tensor>
ms_deform_attn_backward(
const at::Tensor &value,
const at::Tensor &spatial_shapes,
const at::Tensor &level_start_index,
const at::Tensor &sampling_loc,
const at::Tensor &attn_weight,
const at::Tensor &grad_output,
const int im2col_step)
{
if (value.type().is_cuda())
{
#ifdef WITH_CUDA
return ms_deform_attn_cuda_backward(
value, spatial_shapes, level_start_index, sampling_loc, attn_weight, grad_output, im2col_step);
#else
AT_ERROR("Not compiled with GPU support");
#endif
}
AT_ERROR("Not implemented on the CPU");
}
/*!
**************************************************************************************************
* Deformable DETR
* Copyright (c) 2020 SenseTime. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 [see LICENSE for details]
**************************************************************************************************
* Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
**************************************************************************************************
*/
#include "ms_deform_attn.h"
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("ms_deform_attn_forward", &ms_deform_attn_forward, "ms_deform_attn_forward");
m.def("ms_deform_attn_backward", &ms_deform_attn_backward, "ms_deform_attn_backward");
}
\ No newline at end of file
......@@ -125,6 +125,9 @@ class DetaConfig(PretrainedConfig):
Whether to assign each prediction i to the highest overlapping ground truth object if the overlap is larger than a threshold 0.7.
assign_second_stage (`bool`, *optional*, defaults to `True`):
Whether to assign second assignment procedure in the second stage closely follows the first stage assignment procedure.
disable_custom_kernels (`bool`, *optional*, defaults to `True`):
Disable the use of custom CUDA and CPU kernels. This option is necessary for the ONNX export, as custom
kernels are not supported by PyTorch ONNX export.
Examples:
......@@ -191,6 +194,7 @@ class DetaConfig(PretrainedConfig):
giou_loss_coefficient=2,
eos_coefficient=0.1,
focal_alpha=0.25,
disable_custom_kernels=True,
**kwargs,
):
if use_pretrained_backbone:
......@@ -256,6 +260,7 @@ class DetaConfig(PretrainedConfig):
self.giou_loss_coefficient = giou_loss_coefficient
self.eos_coefficient = eos_coefficient
self.focal_alpha = focal_alpha
self.disable_custom_kernels = disable_custom_kernels
super().__init__(is_encoder_decoder=is_encoder_decoder, **kwargs)
@property
......
......@@ -17,13 +17,17 @@
import copy
import math
import os
import warnings
from dataclasses import dataclass
from pathlib import Path
from typing import Dict, List, Optional, Tuple, Union
import torch
import torch.nn.functional as F
from torch import Tensor, nn
from torch.autograd import Function
from torch.autograd.function import once_differentiable
from ...activations import ACT2FN
from ...file_utils import (
......@@ -31,6 +35,7 @@ from ...file_utils import (
add_start_docstrings,
add_start_docstrings_to_model_forward,
is_scipy_available,
is_torch_cuda_available,
is_vision_available,
replace_return_docstrings,
)
......@@ -38,7 +43,7 @@ from ...modeling_attn_mask_utils import _prepare_4d_attention_mask
from ...modeling_outputs import BaseModelOutput
from ...modeling_utils import PreTrainedModel
from ...pytorch_utils import meshgrid
from ...utils import is_accelerate_available, is_torchvision_available, logging, requires_backends
from ...utils import is_accelerate_available, is_ninja_available, is_torchvision_available, logging, requires_backends
from ...utils.backbone_utils import load_backbone
from .configuration_deta import DetaConfig
......@@ -46,6 +51,99 @@ from .configuration_deta import DetaConfig
logger = logging.get_logger(__name__)
def load_cuda_kernels():
from torch.utils.cpp_extension import load
root = Path(__file__).resolve().parent.parent.parent / "kernels" / "deta"
src_files = [
root / filename
for filename in [
"vision.cpp",
os.path.join("cpu", "ms_deform_attn_cpu.cpp"),
os.path.join("cuda", "ms_deform_attn_cuda.cu"),
]
]
load(
"MultiScaleDeformableAttention",
src_files,
with_cuda=True,
extra_include_paths=[str(root)],
extra_cflags=["-DWITH_CUDA=1"],
extra_cuda_cflags=[
"-DCUDA_HAS_FP16=1",
"-D__CUDA_NO_HALF_OPERATORS__",
"-D__CUDA_NO_HALF_CONVERSIONS__",
"-D__CUDA_NO_HALF2_OPERATORS__",
],
)
import MultiScaleDeformableAttention as MSDA
return MSDA
# Move this to not compile only when importing, this needs to happen later, like in __init__.
if is_torch_cuda_available() and is_ninja_available():
logger.info("Loading custom CUDA kernels...")
try:
MultiScaleDeformableAttention = load_cuda_kernels()
except Exception as e:
logger.warning(f"Could not load the custom kernel for multi-scale deformable attention: {e}")
MultiScaleDeformableAttention = None
else:
MultiScaleDeformableAttention = None
# Copied from transformers.models.deformable_detr.modeling_deformable_detr.MultiScaleDeformableAttentionFunction
class MultiScaleDeformableAttentionFunction(Function):
@staticmethod
def forward(
context,
value,
value_spatial_shapes,
value_level_start_index,
sampling_locations,
attention_weights,
im2col_step,
):
context.im2col_step = im2col_step
output = MultiScaleDeformableAttention.ms_deform_attn_forward(
value,
value_spatial_shapes,
value_level_start_index,
sampling_locations,
attention_weights,
context.im2col_step,
)
context.save_for_backward(
value, value_spatial_shapes, value_level_start_index, sampling_locations, attention_weights
)
return output
@staticmethod
@once_differentiable
def backward(context, grad_output):
(
value,
value_spatial_shapes,
value_level_start_index,
sampling_locations,
attention_weights,
) = context.saved_tensors
grad_value, grad_sampling_loc, grad_attn_weight = MultiScaleDeformableAttention.ms_deform_attn_backward(
value,
value_spatial_shapes,
value_level_start_index,
sampling_locations,
attention_weights,
grad_output,
context.im2col_step,
)
return grad_value, None, None, grad_sampling_loc, grad_attn_weight, None
if is_accelerate_available():
from accelerate import PartialState
from accelerate.utils import reduce
......@@ -490,18 +588,19 @@ def multi_scale_deformable_attention(
return output.transpose(1, 2).contiguous()
# Copied from transformers.models.deformable_detr.modeling_deformable_detr.DeformableDetrMultiscaleDeformableAttention with DeformableDetr->Deta
class DetaMultiscaleDeformableAttention(nn.Module):
"""
Multiscale deformable attention as proposed in Deformable DETR.
"""
def __init__(self, embed_dim: int, num_heads: int, n_levels: int, n_points: int):
def __init__(self, config: DetaConfig, num_heads: int, n_points: int):
super().__init__()
if embed_dim % num_heads != 0:
if config.d_model % num_heads != 0:
raise ValueError(
f"embed_dim (d_model) must be divisible by num_heads, but got {embed_dim} and {num_heads}"
f"embed_dim (d_model) must be divisible by num_heads, but got {config.d_model} and {num_heads}"
)
dim_per_head = embed_dim // num_heads
dim_per_head = config.d_model // num_heads
# check if dim_per_head is power of 2
if not ((dim_per_head & (dim_per_head - 1) == 0) and dim_per_head != 0):
warnings.warn(
......@@ -512,15 +611,17 @@ class DetaMultiscaleDeformableAttention(nn.Module):
self.im2col_step = 64
self.d_model = embed_dim
self.n_levels = n_levels
self.d_model = config.d_model
self.n_levels = config.num_feature_levels
self.n_heads = num_heads
self.n_points = n_points
self.sampling_offsets = nn.Linear(embed_dim, num_heads * n_levels * n_points * 2)
self.attention_weights = nn.Linear(embed_dim, num_heads * n_levels * n_points)
self.value_proj = nn.Linear(embed_dim, embed_dim)
self.output_proj = nn.Linear(embed_dim, embed_dim)
self.sampling_offsets = nn.Linear(config.d_model, num_heads * self.n_levels * n_points * 2)
self.attention_weights = nn.Linear(config.d_model, num_heads * self.n_levels * n_points)
self.value_proj = nn.Linear(config.d_model, config.d_model)
self.output_proj = nn.Linear(config.d_model, config.d_model)
self.disable_custom_kernels = config.disable_custom_kernels
self._reset_parameters()
......@@ -598,7 +699,23 @@ class DetaMultiscaleDeformableAttention(nn.Module):
)
else:
raise ValueError(f"Last dim of reference_points must be 2 or 4, but got {reference_points.shape[-1]}")
# PyTorch implementation (for now)
if self.disable_custom_kernels:
# PyTorch implementation
output = multi_scale_deformable_attention(value, spatial_shapes, sampling_locations, attention_weights)
else:
try:
# custom kernel
output = MultiScaleDeformableAttentionFunction.apply(
value,
spatial_shapes,
level_start_index,
sampling_locations,
attention_weights,
self.im2col_step,
)
except Exception:
# PyTorch implementation
output = multi_scale_deformable_attention(value, spatial_shapes, sampling_locations, attention_weights)
output = self.output_proj(output)
......@@ -728,9 +845,8 @@ class DetaEncoderLayer(nn.Module):
super().__init__()
self.embed_dim = config.d_model
self.self_attn = DetaMultiscaleDeformableAttention(
embed_dim=self.embed_dim,
config,
num_heads=config.encoder_attention_heads,
n_levels=config.num_feature_levels,
n_points=config.encoder_n_points,
)
self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim)
......@@ -829,9 +945,8 @@ class DetaDecoderLayer(nn.Module):
self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim)
# cross-attention
self.encoder_attn = DetaMultiscaleDeformableAttention(
embed_dim=self.embed_dim,
config,
num_heads=config.decoder_attention_heads,
n_levels=config.num_feature_levels,
n_points=config.decoder_n_points,
)
self.encoder_attn_layer_norm = nn.LayerNorm(self.embed_dim)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment