Unverified commit edb6d950, authored by fxmarty and committed by GitHub

Add an attribute to disable custom kernels in deformable detr in order to make the model ONNX exportable (#22918)

* add disable kernel option

* add comment

* fix copies

* add disable_custom_kernels to config

* Update src/transformers/models/deta/modeling_deta.py
Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

* Update src/transformers/models/deta/modeling_deta.py
Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

* Update src/transformers/models/deta/modeling_deta.py
Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

* style

* fix

---------
Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
parent 84097f6d
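The new flag is purely config-driven: with `disable_custom_kernels=True` the forward pass stays in plain PyTorch, which is what the ONNX tracer needs. A rough sketch of how an export could look (the dummy input shape, the opset choice and `use_pretrained_backbone=False` are illustrative assumptions, not part of this commit; in practice the export is typically driven through the `optimum` exporters):

```python
import torch
from transformers import DeformableDetrConfig, DeformableDetrForObjectDetection

# Keep everything in pure PyTorch so the tracer never hits the custom CUDA/CPU kernels.
# use_pretrained_backbone=False only avoids a backbone weight download for this sketch.
config = DeformableDetrConfig(disable_custom_kernels=True, use_pretrained_backbone=False)
model = DeformableDetrForObjectDetection(config).eval()
model.config.return_dict = False  # tuple outputs are simpler to export

# Deformable DETR consumes pixel_values of shape (batch_size, 3, height, width).
dummy_pixel_values = torch.randn(1, 3, 800, 800)

torch.onnx.export(
    model,
    (dummy_pixel_values,),
    "deformable_detr.onnx",
    input_names=["pixel_values"],
    opset_version=16,  # grid_sample, used by the PyTorch fallback, needs opset >= 16
)
```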
src/transformers/models/deformable_detr/configuration_deformable_detr.py

@@ -125,6 +125,9 @@ class DeformableDetrConfig(PretrainedConfig):
             based on the predictions from the previous layer.
         focal_alpha (`float`, *optional*, defaults to 0.25):
             Alpha parameter in the focal loss.
+        disable_custom_kernels (`bool`, *optional*, defaults to `False`):
+            Disable the use of custom CUDA and CPU kernels. This option is necessary for the ONNX export, as custom
+            kernels are not supported by PyTorch ONNX export.
 
     Examples:
@@ -189,6 +192,7 @@ class DeformableDetrConfig(PretrainedConfig):
         giou_loss_coefficient=2,
         eos_coefficient=0.1,
         focal_alpha=0.25,
+        disable_custom_kernels=False,
         **kwargs,
     ):
         if backbone_config is not None and use_timm_backbone:
@@ -246,6 +250,7 @@ class DeformableDetrConfig(PretrainedConfig):
         self.giou_loss_coefficient = giou_loss_coefficient
         self.eos_coefficient = eos_coefficient
         self.focal_alpha = focal_alpha
+        self.disable_custom_kernels = disable_custom_kernels
         super().__init__(is_encoder_decoder=is_encoder_decoder, **kwargs)
 
     @property
...
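Because `disable_custom_kernels` is stored on the config like any other attribute, it should round-trip through the usual `PretrainedConfig` serialization. A small sanity-check sketch (the directory name is arbitrary):

```python
from transformers import DeformableDetrConfig

config = DeformableDetrConfig(disable_custom_kernels=True)
print(config.disable_custom_kernels)  # True

# The flag differs from its default, so it is written to config.json and read back.
config.save_pretrained("tmp_deformable_detr_config")
reloaded = DeformableDetrConfig.from_pretrained("tmp_deformable_detr_config")
print(reloaded.disable_custom_kernels)  # True
```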
src/transformers/models/deformable_detr/modeling_deformable_detr.py

@@ -589,13 +589,13 @@ class DeformableDetrMultiscaleDeformableAttention(nn.Module):
     Multiscale deformable attention as proposed in Deformable DETR.
     """
 
-    def __init__(self, embed_dim: int, num_heads: int, n_levels: int, n_points: int):
+    def __init__(self, config: DeformableDetrConfig, num_heads: int, n_points: int):
         super().__init__()
-        if embed_dim % num_heads != 0:
+        if config.d_model % num_heads != 0:
             raise ValueError(
-                f"embed_dim (d_model) must be divisible by num_heads, but got {embed_dim} and {num_heads}"
+                f"embed_dim (d_model) must be divisible by num_heads, but got {config.d_model} and {num_heads}"
             )
-        dim_per_head = embed_dim // num_heads
+        dim_per_head = config.d_model // num_heads
         # check if dim_per_head is power of 2
         if not ((dim_per_head & (dim_per_head - 1) == 0) and dim_per_head != 0):
             warnings.warn(
@@ -606,15 +606,17 @@ class DeformableDetrMultiscaleDeformableAttention(nn.Module):
         self.im2col_step = 64
 
-        self.d_model = embed_dim
-        self.n_levels = n_levels
+        self.d_model = config.d_model
+        self.n_levels = config.num_feature_levels
         self.n_heads = num_heads
         self.n_points = n_points
 
-        self.sampling_offsets = nn.Linear(embed_dim, num_heads * n_levels * n_points * 2)
-        self.attention_weights = nn.Linear(embed_dim, num_heads * n_levels * n_points)
-        self.value_proj = nn.Linear(embed_dim, embed_dim)
-        self.output_proj = nn.Linear(embed_dim, embed_dim)
+        self.sampling_offsets = nn.Linear(config.d_model, num_heads * self.n_levels * n_points * 2)
+        self.attention_weights = nn.Linear(config.d_model, num_heads * self.n_levels * n_points)
+        self.value_proj = nn.Linear(config.d_model, config.d_model)
+        self.output_proj = nn.Linear(config.d_model, config.d_model)
+
+        self.disable_custom_kernels = config.disable_custom_kernels
 
         self._reset_parameters()
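Since the attention module now reads the flag off the config in `__init__`, it propagates to every `DeformableDetrMultiscaleDeformableAttention` instance in the model. A hypothetical check (building the model this way needs `timm` and instantiates random backbone weights):

```python
from transformers import DeformableDetrConfig, DeformableDetrModel

config = DeformableDetrConfig(disable_custom_kernels=True, use_pretrained_backbone=False)
model = DeformableDetrModel(config)

# One entry per deformable attention layer (encoder self-attention and decoder cross-attention).
flags = [
    module.disable_custom_kernels
    for module in model.modules()
    if hasattr(module, "disable_custom_kernels")
]
print(len(flags), all(flags))  # expected: 12 entries, all True, for the default 6+6 layer config
```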
@@ -692,19 +694,24 @@ class DeformableDetrMultiscaleDeformableAttention(nn.Module):
             )
         else:
             raise ValueError(f"Last dim of reference_points must be 2 or 4, but got {reference_points.shape[-1]}")
-        try:
-            # custom kernel
-            output = MultiScaleDeformableAttentionFunction.apply(
-                value,
-                spatial_shapes,
-                level_start_index,
-                sampling_locations,
-                attention_weights,
-                self.im2col_step,
-            )
-        except Exception:
+
+        if self.disable_custom_kernels:
             # PyTorch implementation
             output = multi_scale_deformable_attention(value, spatial_shapes, sampling_locations, attention_weights)
+        else:
+            try:
+                # custom kernel
+                output = MultiScaleDeformableAttentionFunction.apply(
+                    value,
+                    spatial_shapes,
+                    level_start_index,
+                    sampling_locations,
+                    attention_weights,
+                    self.im2col_step,
+                )
+            except Exception:
+                # PyTorch implementation
+                output = multi_scale_deformable_attention(value, spatial_shapes, sampling_locations, attention_weights)
         output = self.output_proj(output)
 
         return output, attention_weights
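The reshuffled control flow above is the core of the fix: previously the kernel choice happened at runtime through `try`/`except`, which the ONNX tracer cannot record, whereas now an ordinary boolean decides the branch before tracing. A stripped-down, generic illustration of the pattern (all names here are made up, not from the library):

```python
def dispatch_attention(value, disable_custom_kernels: bool):
    # Hypothetical stand-ins for the compiled kernel and the pure-PyTorch fallback.
    def custom_kernel(v):
        raise RuntimeError("kernel not compiled on this machine")

    def pytorch_fallback(v):
        return v * 2

    if disable_custom_kernels:
        # Explicit, config-driven branch: a tracer only ever records this path.
        return pytorch_fallback(value)
    try:
        return custom_kernel(value)
    except Exception:
        # Exception-driven fallback: fine in eager mode, invisible to a tracer.
        return pytorch_fallback(value)


print(dispatch_attention(3, disable_custom_kernels=True))   # 6, via the fallback directly
print(dispatch_attention(3, disable_custom_kernels=False))  # 6, fallback after the failed kernel
```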
@@ -832,10 +839,7 @@ class DeformableDetrEncoderLayer(nn.Module):
         super().__init__()
         self.embed_dim = config.d_model
         self.self_attn = DeformableDetrMultiscaleDeformableAttention(
-            embed_dim=self.embed_dim,
-            num_heads=config.encoder_attention_heads,
-            n_levels=config.num_feature_levels,
-            n_points=config.encoder_n_points,
+            config, num_heads=config.encoder_attention_heads, n_points=config.encoder_n_points
         )
         self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim)
         self.dropout = config.dropout
@@ -933,9 +937,8 @@ class DeformableDetrDecoderLayer(nn.Module):
         self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim)
 
         # cross-attention
         self.encoder_attn = DeformableDetrMultiscaleDeformableAttention(
-            embed_dim=self.embed_dim,
+            config,
             num_heads=config.decoder_attention_heads,
-            n_levels=config.num_feature_levels,
             n_points=config.decoder_n_points,
         )
         self.encoder_attn_layer_norm = nn.LayerNorm(self.embed_dim)
...
src/transformers/models/deta/modeling_deta.py

@@ -492,7 +492,6 @@ class DetaMultiscaleDeformableAttention(nn.Module):
     Multiscale deformable attention as proposed in Deformable DETR.
     """
 
-    # Copied from transformers.models.deformable_detr.modeling_deformable_detr.DeformableDetrMultiscaleDeformableAttention.__init__ with DeformableDetr->Deta
     def __init__(self, embed_dim: int, num_heads: int, n_levels: int, n_points: int):
         super().__init__()
         if embed_dim % num_heads != 0:
@@ -721,7 +720,6 @@ class DetaMultiheadAttention(nn.Module):
         return attn_output, attn_weights_reshaped
 
 
-# Copied from transformers.models.deformable_detr.modeling_deformable_detr.DeformableDetrEncoderLayer with DeformableDetr->Deta
 class DetaEncoderLayer(nn.Module):
     def __init__(self, config: DetaConfig):
         super().__init__()
@@ -810,7 +808,6 @@ class DetaEncoderLayer(nn.Module):
         return outputs
 
 
-# Copied from transformers.models.deformable_detr.modeling_deformable_detr.DeformableDetrDecoderLayer with DeformableDetr->Deta
 class DetaDecoderLayer(nn.Module):
     def __init__(self, config: DetaConfig):
         super().__init__()
...