Unverified Commit 85e9d644 authored by Shinji Yamada's avatar Shinji Yamada Committed by GitHub
Browse files

fix: when window_size is passes as array (#26800)

parent b3961f72
......@@ -791,6 +791,11 @@ class Swinv2Stage(nn.Module):
super().__init__()
self.config = config
self.dim = dim
window_size = (
config.window_size
if isinstance(config.window_size, collections.abc.Iterable)
else (config.window_size, config.window_size)
)
self.blocks = nn.ModuleList(
[
Swinv2Layer(
......@@ -798,7 +803,7 @@ class Swinv2Stage(nn.Module):
dim=dim,
input_resolution=input_resolution,
num_heads=num_heads,
shift_size=0 if (i % 2 == 0) else config.window_size // 2,
shift_size=[0, 0] if (i % 2 == 0) else [window_size[0] // 2, window_size[1] // 2],
pretrained_window_size=pretrained_window_size,
)
for i in range(depth)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment