Unverified Commit b651efe5 authored by lewtun's avatar lewtun Committed by GitHub
Browse files

[Swin] Replace hard-coded batch size to enable dynamic ONNX export (#19475)

* [Swin] Replace hard-coded batch size to enable dynamic ONNX export
parent 440bbd44
...@@ -144,9 +144,9 @@ def window_reverse(windows, window_size, height, width): ...@@ -144,9 +144,9 @@ def window_reverse(windows, window_size, height, width):
""" """
Merges windows to produce higher resolution features. Merges windows to produce higher resolution features.
""" """
batch_size = math.floor(windows.shape[0] / (height * width / window_size / window_size)) num_channels = windows.shape[-1]
windows = windows.view(batch_size, height // window_size, width // window_size, window_size, window_size, -1) windows = windows.view(-1, height // window_size, width // window_size, window_size, window_size, num_channels)
windows = windows.permute(0, 1, 3, 2, 4, 5).contiguous().view(batch_size, height, width, -1) windows = windows.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, height, width, num_channels)
return windows return windows
......
...@@ -490,9 +490,9 @@ def window_reverse(windows, window_size, height, width): ...@@ -490,9 +490,9 @@ def window_reverse(windows, window_size, height, width):
""" """
Merges windows to produce higher resolution features. Merges windows to produce higher resolution features.
""" """
batch_size = math.floor(windows.shape[0] / (height * width / window_size / window_size)) num_channels = windows.shape[-1]
windows = windows.view(batch_size, height // window_size, width // window_size, window_size, window_size, -1) windows = windows.view(-1, height // window_size, width // window_size, window_size, window_size, num_channels)
windows = windows.permute(0, 1, 3, 2, 4, 5).contiguous().view(batch_size, height, width, -1) windows = windows.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, height, width, num_channels)
return windows return windows
......
...@@ -219,9 +219,9 @@ def window_reverse(windows, window_size, height, width): ...@@ -219,9 +219,9 @@ def window_reverse(windows, window_size, height, width):
""" """
Merges windows to produce higher resolution features. Merges windows to produce higher resolution features.
""" """
batch_size = math.floor(windows.shape[0] / (height * width / window_size / window_size)) num_channels = windows.shape[-1]
windows = windows.view(batch_size, height // window_size, width // window_size, window_size, window_size, -1) windows = windows.view(-1, height // window_size, width // window_size, window_size, window_size, num_channels)
windows = windows.permute(0, 1, 3, 2, 4, 5).contiguous().view(batch_size, height, width, -1) windows = windows.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, height, width, num_channels)
return windows return windows
......
...@@ -227,9 +227,9 @@ def window_reverse(windows, window_size, height, width): ...@@ -227,9 +227,9 @@ def window_reverse(windows, window_size, height, width):
""" """
Merges windows to produce higher resolution features. Merges windows to produce higher resolution features.
""" """
batch_size = math.floor(windows.shape[0] / (height * width / window_size / window_size)) num_channels = windows.shape[-1]
windows = windows.view(batch_size, height // window_size, width // window_size, window_size, window_size, -1) windows = windows.view(-1, height // window_size, width // window_size, window_size, window_size, num_channels)
windows = windows.permute(0, 1, 3, 2, 4, 5).contiguous().view(batch_size, height, width, -1) windows = windows.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, height, width, num_channels)
return windows return windows
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment