Unverified Commit edb6d950 authored by fxmarty's avatar fxmarty Committed by GitHub
Browse files

Add an attribute to disable custom kernels in deformable detr in order to make the model ONNX exportable (#22918)


Add an attribute to disable custom kernels in deformable detr in order to make the model ONNX exportable (#22918)

* add disable kernel option

* add comment

* fix copies

* add disable_custom_kernels to config

* Update src/transformers/models/deta/modeling_deta.py
Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

* Update src/transformers/models/deta/modeling_deta.py
Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

* Update src/transformers/models/deta/modeling_deta.py
Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

* style

* fix

---------
Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
parent 84097f6d
......@@ -125,6 +125,9 @@ class DeformableDetrConfig(PretrainedConfig):
based on the predictions from the previous layer.
focal_alpha (`float`, *optional*, defaults to 0.25):
Alpha parameter in the focal loss.
disable_custom_kernels (`bool`, *optional*, defaults to `False`):
Disable the use of custom CUDA and CPU kernels. This option is necessary for the ONNX export, as custom
kernels are not supported by PyTorch ONNX export.
Examples:
......@@ -189,6 +192,7 @@ class DeformableDetrConfig(PretrainedConfig):
giou_loss_coefficient=2,
eos_coefficient=0.1,
focal_alpha=0.25,
disable_custom_kernels=False,
**kwargs,
):
if backbone_config is not None and use_timm_backbone:
......@@ -246,6 +250,7 @@ class DeformableDetrConfig(PretrainedConfig):
self.giou_loss_coefficient = giou_loss_coefficient
self.eos_coefficient = eos_coefficient
self.focal_alpha = focal_alpha
self.disable_custom_kernels = disable_custom_kernels
super().__init__(is_encoder_decoder=is_encoder_decoder, **kwargs)
@property
......
......@@ -589,13 +589,13 @@ class DeformableDetrMultiscaleDeformableAttention(nn.Module):
Multiscale deformable attention as proposed in Deformable DETR.
"""
def __init__(self, embed_dim: int, num_heads: int, n_levels: int, n_points: int):
def __init__(self, config: DeformableDetrConfig, num_heads: int, n_points: int):
super().__init__()
if embed_dim % num_heads != 0:
if config.d_model % num_heads != 0:
raise ValueError(
f"embed_dim (d_model) must be divisible by num_heads, but got {embed_dim} and {num_heads}"
f"embed_dim (d_model) must be divisible by num_heads, but got {config.d_model} and {num_heads}"
)
dim_per_head = embed_dim // num_heads
dim_per_head = config.d_model // num_heads
# check if dim_per_head is power of 2
if not ((dim_per_head & (dim_per_head - 1) == 0) and dim_per_head != 0):
warnings.warn(
......@@ -606,15 +606,17 @@ class DeformableDetrMultiscaleDeformableAttention(nn.Module):
self.im2col_step = 64
self.d_model = embed_dim
self.n_levels = n_levels
self.d_model = config.d_model
self.n_levels = config.num_feature_levels
self.n_heads = num_heads
self.n_points = n_points
self.sampling_offsets = nn.Linear(embed_dim, num_heads * n_levels * n_points * 2)
self.attention_weights = nn.Linear(embed_dim, num_heads * n_levels * n_points)
self.value_proj = nn.Linear(embed_dim, embed_dim)
self.output_proj = nn.Linear(embed_dim, embed_dim)
self.sampling_offsets = nn.Linear(config.d_model, num_heads * self.n_levels * n_points * 2)
self.attention_weights = nn.Linear(config.d_model, num_heads * self.n_levels * n_points)
self.value_proj = nn.Linear(config.d_model, config.d_model)
self.output_proj = nn.Linear(config.d_model, config.d_model)
self.disable_custom_kernels = config.disable_custom_kernels
self._reset_parameters()
......@@ -692,6 +694,11 @@ class DeformableDetrMultiscaleDeformableAttention(nn.Module):
)
else:
raise ValueError(f"Last dim of reference_points must be 2 or 4, but got {reference_points.shape[-1]}")
if self.disable_custom_kernels:
# PyTorch implementation
output = multi_scale_deformable_attention(value, spatial_shapes, sampling_locations, attention_weights)
else:
try:
# custom kernel
output = MultiScaleDeformableAttentionFunction.apply(
......@@ -832,10 +839,7 @@ class DeformableDetrEncoderLayer(nn.Module):
super().__init__()
self.embed_dim = config.d_model
self.self_attn = DeformableDetrMultiscaleDeformableAttention(
embed_dim=self.embed_dim,
num_heads=config.encoder_attention_heads,
n_levels=config.num_feature_levels,
n_points=config.encoder_n_points,
config, num_heads=config.encoder_attention_heads, n_points=config.encoder_n_points
)
self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim)
self.dropout = config.dropout
......@@ -933,9 +937,8 @@ class DeformableDetrDecoderLayer(nn.Module):
self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim)
# cross-attention
self.encoder_attn = DeformableDetrMultiscaleDeformableAttention(
embed_dim=self.embed_dim,
config,
num_heads=config.decoder_attention_heads,
n_levels=config.num_feature_levels,
n_points=config.decoder_n_points,
)
self.encoder_attn_layer_norm = nn.LayerNorm(self.embed_dim)
......
......@@ -492,7 +492,6 @@ class DetaMultiscaleDeformableAttention(nn.Module):
Multiscale deformable attention as proposed in Deformable DETR.
"""
# Copied from transformers.models.deformable_detr.modeling_deformable_detr.DeformableDetrMultiscaleDeformableAttention.__init__ with DeformableDetr->Deta
def __init__(self, embed_dim: int, num_heads: int, n_levels: int, n_points: int):
super().__init__()
if embed_dim % num_heads != 0:
......@@ -721,7 +720,6 @@ class DetaMultiheadAttention(nn.Module):
return attn_output, attn_weights_reshaped
# Copied from transformers.models.deformable_detr.modeling_deformable_detr.DeformableDetrEncoderLayer with DeformableDetr->Deta
class DetaEncoderLayer(nn.Module):
def __init__(self, config: DetaConfig):
super().__init__()
......@@ -810,7 +808,6 @@ class DetaEncoderLayer(nn.Module):
return outputs
# Copied from transformers.models.deformable_detr.modeling_deformable_detr.DeformableDetrDecoderLayer with DeformableDetr->Deta
class DetaDecoderLayer(nn.Module):
def __init__(self, config: DetaConfig):
super().__init__()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment