Unverified Commit 85e9d644 authored by Shinji Yamada's avatar Shinji Yamada Committed by GitHub
Browse files

fix: when window_size is passes as array (#26800)

parent b3961f72
...@@ -791,6 +791,11 @@ class Swinv2Stage(nn.Module): ...@@ -791,6 +791,11 @@ class Swinv2Stage(nn.Module):
super().__init__() super().__init__()
self.config = config self.config = config
self.dim = dim self.dim = dim
window_size = (
config.window_size
if isinstance(config.window_size, collections.abc.Iterable)
else (config.window_size, config.window_size)
)
self.blocks = nn.ModuleList( self.blocks = nn.ModuleList(
[ [
Swinv2Layer( Swinv2Layer(
...@@ -798,7 +803,7 @@ class Swinv2Stage(nn.Module): ...@@ -798,7 +803,7 @@ class Swinv2Stage(nn.Module):
dim=dim, dim=dim,
input_resolution=input_resolution, input_resolution=input_resolution,
num_heads=num_heads, num_heads=num_heads,
shift_size=0 if (i % 2 == 0) else config.window_size // 2, shift_size=[0, 0] if (i % 2 == 0) else [window_size[0] // 2, window_size[1] // 2],
pretrained_window_size=pretrained_window_size, pretrained_window_size=pretrained_window_size,
) )
for i in range(depth) for i in range(depth)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment