Unverified commit 7320d95d authored by NielsRogge, committed by GitHub

[Swin, Swinv2] Fix attn_mask dtype (#18803)



* Add dtype

* Fix Swinv2 as well
Co-authored-by: Niels Rogge <nielsrogge@Nielss-MacBook-Pro.local>
parent 5c702175
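
The same two-line change is applied to three copies of the shifted-window attention code: `DonutSwinLayer`, `SwinLayer`, and `Swinv2Layer`. Before this commit, `get_attn_mask` built `img_mask` with `torch.zeros(...)`, which uses the default dtype (normally `torch.float32`). Under half-precision inference the hidden states are `torch.float16` or `torch.bfloat16`, so the mask's dtype diverged from the attention scores it is added to. Passing `dtype=hidden_states.dtype` through keeps them aligned. A minimal sketch of the mismatch, with made-up shapes rather than the actual layer code:

```python
import torch

# Sketch of the bug this commit fixes (hypothetical shapes, not SwinLayer
# itself). torch.zeros defaults to torch.get_default_dtype(), usually
# float32, no matter what dtype the model is running in.
hidden_states = torch.randn(1, 56 * 56, 96).half()  # fp16 activations

mask_before = torch.zeros((1, 56, 56, 1))                            # float32
mask_after = torch.zeros((1, 56, 56, 1), dtype=hidden_states.dtype)  # float16

print(mask_before.dtype, mask_after.dtype)  # torch.float32 torch.float16
```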
--- a/src/transformers/models/donut/modeling_donut_swin.py
+++ b/src/transformers/models/donut/modeling_donut_swin.py
@@ -538,10 +538,10 @@ class DonutSwinLayer(nn.Module):
             self.shift_size = 0
             self.window_size = min(input_resolution)
-    def get_attn_mask(self, height, width):
+    def get_attn_mask(self, height, width, dtype):
         if self.shift_size > 0:
             # calculate attention mask for SW-MSA
-            img_mask = torch.zeros((1, height, width, 1))
+            img_mask = torch.zeros((1, height, width, 1), dtype=dtype)
             height_slices = (
                 slice(0, -self.window_size),
                 slice(-self.window_size, -self.shift_size),
@@ -600,7 +600,7 @@ class DonutSwinLayer(nn.Module):
         # partition windows
         hidden_states_windows = window_partition(shifted_hidden_states, self.window_size)
         hidden_states_windows = hidden_states_windows.view(-1, self.window_size * self.window_size, channels)
-        attn_mask = self.get_attn_mask(height_pad, width_pad)
+        attn_mask = self.get_attn_mask(height_pad, width_pad, dtype=hidden_states.dtype)
         if attn_mask is not None:
             attn_mask = attn_mask.to(hidden_states_windows.device)
--- a/src/transformers/models/swin/modeling_swin.py
+++ b/src/transformers/models/swin/modeling_swin.py
@@ -604,10 +604,10 @@ class SwinLayer(nn.Module):
             self.shift_size = 0
             self.window_size = min(input_resolution)
-    def get_attn_mask(self, height, width):
+    def get_attn_mask(self, height, width, dtype):
         if self.shift_size > 0:
             # calculate attention mask for SW-MSA
-            img_mask = torch.zeros((1, height, width, 1))
+            img_mask = torch.zeros((1, height, width, 1), dtype=dtype)
             height_slices = (
                 slice(0, -self.window_size),
                 slice(-self.window_size, -self.shift_size),
@@ -666,7 +666,7 @@ class SwinLayer(nn.Module):
         # partition windows
         hidden_states_windows = window_partition(shifted_hidden_states, self.window_size)
         hidden_states_windows = hidden_states_windows.view(-1, self.window_size * self.window_size, channels)
-        attn_mask = self.get_attn_mask(height_pad, width_pad)
+        attn_mask = self.get_attn_mask(height_pad, width_pad, dtype=hidden_states.dtype)
         if attn_mask is not None:
             attn_mask = attn_mask.to(hidden_states_windows.device)
--- a/src/transformers/models/swinv2/modeling_swinv2.py
+++ b/src/transformers/models/swinv2/modeling_swinv2.py
@@ -676,10 +676,10 @@ class Swinv2Layer(nn.Module):
             else target_shift_size[0]
         )
-    def get_attn_mask(self, height, width):
+    def get_attn_mask(self, height, width, dtype):
         if self.shift_size > 0:
             # calculate attention mask for shifted window multihead self attention
-            img_mask = torch.zeros((1, height, width, 1))
+            img_mask = torch.zeros((1, height, width, 1), dtype=dtype)
             height_slices = (
                 slice(0, -self.window_size),
                 slice(-self.window_size, -self.shift_size),
@@ -736,7 +736,7 @@ class Swinv2Layer(nn.Module):
         # partition windows
         hidden_states_windows = window_partition(shifted_hidden_states, self.window_size)
         hidden_states_windows = hidden_states_windows.view(-1, self.window_size * self.window_size, channels)
-        attn_mask = self.get_attn_mask(height_pad, width_pad)
+        attn_mask = self.get_attn_mask(height_pad, width_pad, dtype=hidden_states.dtype)
         if attn_mask is not None:
             attn_mask = attn_mask.to(hidden_states_windows.device)
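
As a sanity check, a half-precision forward pass should now run end to end with the mask created in the activations' dtype. A hypothetical snippet (the checkpoint name is assumed and not part of this commit; fp16 inference generally needs a GPU):

```python
import torch
from transformers import SwinModel

# Hypothetical sanity check, not part of this commit: a fp16 forward pass
# through Swin exercises the shifted-window attention mask fixed above.
device = "cuda" if torch.cuda.is_available() else "cpu"
dtype = torch.float16 if device == "cuda" else torch.float32  # fp16 on CPU is poorly supported

model = SwinModel.from_pretrained("microsoft/swin-tiny-patch4-window7-224")
model = model.to(device=device, dtype=dtype).eval()

pixel_values = torch.rand(1, 3, 224, 224, device=device, dtype=dtype)
with torch.no_grad():
    outputs = model(pixel_values)
print(outputs.last_hidden_state.dtype)  # torch.float16 on a GPU
```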