Unverified Commit 8f198e53 authored by Aditya Oke, committed by GitHub

Fix dropout issue in swin transformers (#7224)


Co-authored-by: Nicolas Hug <contact@nicolas-hug.com>
parent 55d3ba62
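The diff below threads the module's training flag through shifted_window_attention and shifted_window_attention_3d into their functional F.dropout calls. The reason is that torch.nn.functional.dropout defaults to training=True, so without the explicit flag the functional helpers kept dropping activations even when the surrounding module was in eval mode. A minimal sketch of that default behavior (not part of the patch, plain PyTorch only):

import torch
import torch.nn.functional as F

# torch.nn.functional.dropout defaults to training=True, so it keeps zeroing
# activations at inference time unless the caller forwards the module's
# training flag explicitly.
x = torch.ones(4, 4)

still_active = F.dropout(x, p=0.5)                   # elements still dropped
respects_eval = F.dropout(x, p=0.5, training=False)  # identity when not training

assert torch.equal(respects_eval, x)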
@@ -126,7 +126,8 @@ def shifted_window_attention(
     qkv_bias: Optional[Tensor] = None,
     proj_bias: Optional[Tensor] = None,
     logit_scale: Optional[torch.Tensor] = None,
-):
+    training: bool = True,
+) -> Tensor:
     """
     Window based multi-head self attention (W-MSA) module with relative position bias.
     It supports both of shifted and non-shifted window.
@@ -143,6 +144,7 @@ def shifted_window_attention(
         qkv_bias (Tensor[out_dim], optional): The bias tensor of query, key, value. Default: None.
         proj_bias (Tensor[out_dim], optional): The bias tensor of projection. Default: None.
         logit_scale (Tensor[out_dim], optional): Logit scale of cosine attention for Swin Transformer V2. Default: None.
+        training (bool, optional): Training flag used by the dropout parameters. Default: True.
     Returns:
         Tensor[N, H, W, C]: The output tensor after shifted window attention.
     """
@@ -207,11 +209,11 @@ def shifted_window_attention(
         attn = attn.view(-1, num_heads, x.size(1), x.size(1))

     attn = F.softmax(attn, dim=-1)
-    attn = F.dropout(attn, p=attention_dropout)
+    attn = F.dropout(attn, p=attention_dropout, training=training)

     x = attn.matmul(v).transpose(1, 2).reshape(x.size(0), x.size(1), C)
     x = F.linear(x, proj_weight, proj_bias)
-    x = F.dropout(x, p=dropout)
+    x = F.dropout(x, p=dropout, training=training)

     # reverse windows
     x = x.view(B, pad_H // window_size[0], pad_W // window_size[1], window_size[0], window_size[1], C)
@@ -286,7 +288,7 @@ class ShiftedWindowAttention(nn.Module):
             self.relative_position_bias_table, self.relative_position_index, self.window_size  # type: ignore[arg-type]
         )

-    def forward(self, x: Tensor):
+    def forward(self, x: Tensor) -> Tensor:
         """
         Args:
             x (Tensor): Tensor with layout of [B, H, W, C]
@@ -306,6 +308,7 @@ class ShiftedWindowAttention(nn.Module):
             dropout=self.dropout,
             qkv_bias=self.qkv.bias,
             proj_bias=self.proj.bias,
+            training=self.training,
         )
@@ -391,6 +394,7 @@ class ShiftedWindowAttentionV2(ShiftedWindowAttention):
             qkv_bias=self.qkv.bias,
             proj_bias=self.proj.bias,
             logit_scale=self.logit_scale,
+            training=self.training,
         )
@@ -124,6 +124,7 @@ def shifted_window_attention_3d(
     dropout: float = 0.0,
     qkv_bias: Optional[Tensor] = None,
     proj_bias: Optional[Tensor] = None,
+    training: bool = True,
 ) -> Tensor:
     """
     Window based multi-head self attention (W-MSA) module with relative position bias.
@@ -140,6 +141,7 @@ def shifted_window_attention_3d(
         dropout (float): Dropout ratio of output. Default: 0.0.
         qkv_bias (Tensor[out_dim], optional): The bias tensor of query, key, value. Default: None.
         proj_bias (Tensor[out_dim], optional): The bias tensor of projection. Default: None.
+        training (bool, optional): Training flag used by the dropout parameters. Default: True.
     Returns:
         Tensor[B, T, H, W, C]: The output tensor after shifted window attention.
     """
@@ -194,11 +196,11 @@ def shifted_window_attention_3d(
         attn = attn.view(-1, num_heads, x.size(1), x.size(1))

     attn = F.softmax(attn, dim=-1)
-    attn = F.dropout(attn, p=attention_dropout)
+    attn = F.dropout(attn, p=attention_dropout, training=training)

     x = attn.matmul(v).transpose(1, 2).reshape(x.size(0), x.size(1), c)
     x = F.linear(x, proj_weight, proj_bias)
-    x = F.dropout(x, p=dropout)
+    x = F.dropout(x, p=dropout, training=training)

     # reverse windows
     x = x.view(
@@ -310,6 +312,7 @@ class ShiftedWindowAttention3d(nn.Module):
             dropout=self.dropout,
             qkv_bias=self.qkv.bias,
             proj_bias=self.proj.bias,
+            training=self.training,
         )
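For illustration, a hypothetical toy module (ToyWindowDropout is not torchvision code) mirroring the pattern used by the patch: the functional helper receives self.training rather than relying on F.dropout's default of training=True, so switching the module to eval() actually disables dropout.

import torch
import torch.nn as nn
import torch.nn.functional as F

class ToyWindowDropout(nn.Module):
    """Forwards the module's training flag to a functional dropout call."""

    def __init__(self, p: float = 0.5) -> None:
        super().__init__()
        self.p = p

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # self.training is toggled by .train() / .eval() on the module.
        return F.dropout(x, p=self.p, training=self.training)

m = ToyWindowDropout().eval()
x = torch.ones(2, 3)
assert torch.equal(m(x), x)  # dropout is a no-op in eval mode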