Unverified Commit 8f198e53 authored by Aditya Oke, committed by GitHub

Fix dropout issue in swin transformers (#7224)


Co-authored-by: Nicolas Hug <contact@nicolas-hug.com>
parent 55d3ba62
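
For context: torch.nn.functional.dropout defaults to training=True, so the bare F.dropout(...) calls below stayed active even after model.eval(); the patch threads each module's self.training flag through the functional shifted_window_attention helpers. A minimal sketch of that default behavior (standalone illustration, not part of the patch):

# Standalone sketch (not part of this patch): F.dropout defaults to
# training=True, so it keeps zeroing activations unless the caller forwards
# the owning module's training flag.
import torch
import torch.nn.functional as F

x = torch.ones(4, 4)

y_default = F.dropout(x, p=0.5)               # always active: ~half zeroed, survivors scaled by 2
y_eval = F.dropout(x, p=0.5, training=False)  # identity when the flag says "not training"

print(torch.equal(y_eval, x))     # True
print(torch.equal(y_default, x))  # False
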
@@ -126,7 +126,8 @@ def shifted_window_attention(
     qkv_bias: Optional[Tensor] = None,
     proj_bias: Optional[Tensor] = None,
     logit_scale: Optional[torch.Tensor] = None,
-):
+    training: bool = True,
+) -> Tensor:
     """
     Window based multi-head self attention (W-MSA) module with relative position bias.
     It supports both of shifted and non-shifted window.
@@ -143,6 +144,7 @@ def shifted_window_attention(
         qkv_bias (Tensor[out_dim], optional): The bias tensor of query, key, value. Default: None.
         proj_bias (Tensor[out_dim], optional): The bias tensor of projection. Default: None.
         logit_scale (Tensor[out_dim], optional): Logit scale of cosine attention for Swin Transformer V2. Default: None.
+        training (bool, optional): Training flag used by the dropout parameters. Default: True.
     Returns:
         Tensor[N, H, W, C]: The output tensor after shifted window attention.
     """
@@ -207,11 +209,11 @@ def shifted_window_attention(
     attn = attn.view(-1, num_heads, x.size(1), x.size(1))
     attn = F.softmax(attn, dim=-1)
-    attn = F.dropout(attn, p=attention_dropout)
+    attn = F.dropout(attn, p=attention_dropout, training=training)
 
     x = attn.matmul(v).transpose(1, 2).reshape(x.size(0), x.size(1), C)
     x = F.linear(x, proj_weight, proj_bias)
-    x = F.dropout(x, p=dropout)
+    x = F.dropout(x, p=dropout, training=training)
 
     # reverse windows
     x = x.view(B, pad_H // window_size[0], pad_W // window_size[1], window_size[0], window_size[1], C)
@@ -286,7 +288,7 @@ class ShiftedWindowAttention(nn.Module):
             self.relative_position_bias_table, self.relative_position_index, self.window_size  # type: ignore[arg-type]
         )
 
-    def forward(self, x: Tensor):
+    def forward(self, x: Tensor) -> Tensor:
         """
         Args:
             x (Tensor): Tensor with layout of [B, H, W, C]
@@ -306,6 +308,7 @@ class ShiftedWindowAttention(nn.Module):
             dropout=self.dropout,
             qkv_bias=self.qkv.bias,
             proj_bias=self.proj.bias,
+            training=self.training,
         )
@@ -391,6 +394,7 @@ class ShiftedWindowAttentionV2(ShiftedWindowAttention):
             qkv_bias=self.qkv.bias,
             proj_bias=self.proj.bias,
             logit_scale=self.logit_scale,
+            training=self.training,
         )
@@ -124,6 +124,7 @@ def shifted_window_attention_3d(
     dropout: float = 0.0,
     qkv_bias: Optional[Tensor] = None,
     proj_bias: Optional[Tensor] = None,
+    training: bool = True,
 ) -> Tensor:
     """
     Window based multi-head self attention (W-MSA) module with relative position bias.
@@ -140,6 +141,7 @@ def shifted_window_attention_3d(
         dropout (float): Dropout ratio of output. Default: 0.0.
         qkv_bias (Tensor[out_dim], optional): The bias tensor of query, key, value. Default: None.
         proj_bias (Tensor[out_dim], optional): The bias tensor of projection. Default: None.
+        training (bool, optional): Training flag used by the dropout parameters. Default: True.
     Returns:
         Tensor[B, T, H, W, C]: The output tensor after shifted window attention.
     """
@@ -194,11 +196,11 @@ def shifted_window_attention_3d(
     attn = attn.view(-1, num_heads, x.size(1), x.size(1))
     attn = F.softmax(attn, dim=-1)
-    attn = F.dropout(attn, p=attention_dropout)
+    attn = F.dropout(attn, p=attention_dropout, training=training)
 
     x = attn.matmul(v).transpose(1, 2).reshape(x.size(0), x.size(1), c)
     x = F.linear(x, proj_weight, proj_bias)
-    x = F.dropout(x, p=dropout)
+    x = F.dropout(x, p=dropout, training=training)
 
     # reverse windows
     x = x.view(
@@ -310,6 +312,7 @@ class ShiftedWindowAttention3d(nn.Module):
             dropout=self.dropout,
             qkv_bias=self.qkv.bias,
             proj_bias=self.proj.bias,
+            training=self.training,
         )
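
With self.training threaded through, the attention blocks are deterministic under .eval() even when non-zero dropout rates are configured. A quick determinism check could look like the sketch below; the import path and constructor arguments follow the ShiftedWindowAttention class touched above but are assumptions about the surrounding module, not part of the patch:

# Sketch of an eval-mode determinism check (assumed import path and
# constructor defaults; not part of this patch).
import torch
from torchvision.models.swin_transformer import ShiftedWindowAttention

attn = ShiftedWindowAttention(
    dim=96,
    window_size=[7, 7],
    shift_size=[3, 3],
    num_heads=3,
    attention_dropout=0.2,
    dropout=0.2,
).eval()  # eval() sets self.training = False, which is now forwarded to F.dropout

x = torch.randn(1, 56, 56, 96)  # [B, H, W, C] layout, as documented in the diff
with torch.no_grad():
    out1 = attn(x)
    out2 = attn(x)

print(torch.allclose(out1, out2))  # True with this fix; active dropout made it False before
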