Unverified commit 7320d95d, authored by NielsRogge and committed by GitHub
Browse files

[Swin, Swinv2] Fix attn_mask dtype (#18803)



* Add dtype

* Fix Swinv2 as well
Co-authored-by: Niels Rogge <nielsrogge@Nielss-MacBook-Pro.local>
parent 5c702175
...@@ -538,10 +538,10 @@ class DonutSwinLayer(nn.Module): ...@@ -538,10 +538,10 @@ class DonutSwinLayer(nn.Module):
self.shift_size = 0 self.shift_size = 0
self.window_size = min(input_resolution) self.window_size = min(input_resolution)
def get_attn_mask(self, height, width): def get_attn_mask(self, height, width, dtype):
if self.shift_size > 0: if self.shift_size > 0:
# calculate attention mask for SW-MSA # calculate attention mask for SW-MSA
img_mask = torch.zeros((1, height, width, 1)) img_mask = torch.zeros((1, height, width, 1), dtype=dtype)
height_slices = ( height_slices = (
slice(0, -self.window_size), slice(0, -self.window_size),
slice(-self.window_size, -self.shift_size), slice(-self.window_size, -self.shift_size),
...@@ -600,7 +600,7 @@ class DonutSwinLayer(nn.Module): ...@@ -600,7 +600,7 @@ class DonutSwinLayer(nn.Module):
# partition windows # partition windows
hidden_states_windows = window_partition(shifted_hidden_states, self.window_size) hidden_states_windows = window_partition(shifted_hidden_states, self.window_size)
hidden_states_windows = hidden_states_windows.view(-1, self.window_size * self.window_size, channels) hidden_states_windows = hidden_states_windows.view(-1, self.window_size * self.window_size, channels)
attn_mask = self.get_attn_mask(height_pad, width_pad) attn_mask = self.get_attn_mask(height_pad, width_pad, dtype=hidden_states.dtype)
if attn_mask is not None: if attn_mask is not None:
attn_mask = attn_mask.to(hidden_states_windows.device) attn_mask = attn_mask.to(hidden_states_windows.device)
......
...@@ -604,10 +604,10 @@ class SwinLayer(nn.Module): ...@@ -604,10 +604,10 @@ class SwinLayer(nn.Module):
self.shift_size = 0 self.shift_size = 0
self.window_size = min(input_resolution) self.window_size = min(input_resolution)
def get_attn_mask(self, height, width): def get_attn_mask(self, height, width, dtype):
if self.shift_size > 0: if self.shift_size > 0:
# calculate attention mask for SW-MSA # calculate attention mask for SW-MSA
img_mask = torch.zeros((1, height, width, 1)) img_mask = torch.zeros((1, height, width, 1), dtype=dtype)
height_slices = ( height_slices = (
slice(0, -self.window_size), slice(0, -self.window_size),
slice(-self.window_size, -self.shift_size), slice(-self.window_size, -self.shift_size),
...@@ -666,7 +666,7 @@ class SwinLayer(nn.Module): ...@@ -666,7 +666,7 @@ class SwinLayer(nn.Module):
# partition windows # partition windows
hidden_states_windows = window_partition(shifted_hidden_states, self.window_size) hidden_states_windows = window_partition(shifted_hidden_states, self.window_size)
hidden_states_windows = hidden_states_windows.view(-1, self.window_size * self.window_size, channels) hidden_states_windows = hidden_states_windows.view(-1, self.window_size * self.window_size, channels)
attn_mask = self.get_attn_mask(height_pad, width_pad) attn_mask = self.get_attn_mask(height_pad, width_pad, dtype=hidden_states.dtype)
if attn_mask is not None: if attn_mask is not None:
attn_mask = attn_mask.to(hidden_states_windows.device) attn_mask = attn_mask.to(hidden_states_windows.device)
......
...@@ -676,10 +676,10 @@ class Swinv2Layer(nn.Module): ...@@ -676,10 +676,10 @@ class Swinv2Layer(nn.Module):
else target_shift_size[0] else target_shift_size[0]
) )
def get_attn_mask(self, height, width): def get_attn_mask(self, height, width, dtype):
if self.shift_size > 0: if self.shift_size > 0:
# calculate attention mask for shifted window multihead self attention # calculate attention mask for shifted window multihead self attention
img_mask = torch.zeros((1, height, width, 1)) img_mask = torch.zeros((1, height, width, 1), dtype=dtype)
height_slices = ( height_slices = (
slice(0, -self.window_size), slice(0, -self.window_size),
slice(-self.window_size, -self.shift_size), slice(-self.window_size, -self.shift_size),
...@@ -736,7 +736,7 @@ class Swinv2Layer(nn.Module): ...@@ -736,7 +736,7 @@ class Swinv2Layer(nn.Module):
# partition windows # partition windows
hidden_states_windows = window_partition(shifted_hidden_states, self.window_size) hidden_states_windows = window_partition(shifted_hidden_states, self.window_size)
hidden_states_windows = hidden_states_windows.view(-1, self.window_size * self.window_size, channels) hidden_states_windows = hidden_states_windows.view(-1, self.window_size * self.window_size, channels)
attn_mask = self.get_attn_mask(height_pad, width_pad) attn_mask = self.get_attn_mask(height_pad, width_pad, dtype=hidden_states.dtype)
if attn_mask is not None: if attn_mask is not None:
attn_mask = attn_mask.to(hidden_states_windows.device) attn_mask = attn_mask.to(hidden_states_windows.device)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment