Unverified Commit ef38e2a7 authored by fxmarty's avatar fxmarty Committed by GitHub
Browse files

Make vitdet jit trace complient (#30065)

* remove controlflows

* style

* rename patch_ to padded_ following review comment

* style
parent a71def02
......@@ -94,11 +94,12 @@ class VitDetEmbeddings(nn.Module):
if has_cls_token:
abs_pos_embeddings = abs_pos_embeddings[:, 1:]
num_position = abs_pos_embeddings.shape[1]
size = int(math.sqrt(num_position))
size = int(math.sqrt(num_position)) # This is a constant and can be recorded as such in the ONNX export.
if size * size != num_position:
raise ValueError("Absolute position embeddings must be a square number.")
if size != height or size != width:
if torch.jit.is_tracing() or (size != height or size != width):
# nn.functional.interpolate is a noop in case size == height and size == width - we need to always capture this path with jit.trace.
new_abs_pos_embeddings = nn.functional.interpolate(
abs_pos_embeddings.reshape(1, size, size, -1).permute(0, 3, 1, 2),
size=(height, width),
......@@ -132,6 +133,7 @@ class VitDetEmbeddings(nn.Module):
return embeddings
@torch.jit.script_if_tracing # nn.functional.interpolate's `size` needs to be dynamic.
def get_rel_pos(q_size, k_size, rel_pos):
"""
Get relative positional embeddings according to the relative positions of query and key sizes.
......@@ -399,21 +401,23 @@ def window_partition(hidden_state, window_size):
Returns:
`tuple(torch.FloatTensor)` comprising various elements:
- windows: windows after partition with [batch_size * num_windows, window_size, window_size, num_channels].
- (patch_height, patch_width): padded height and width before partition
- (padded_height, padded_width): padded height and width before partition
"""
batch_size, height, width, num_channels = hidden_state.shape
pad_height = (window_size - height % window_size) % window_size
pad_width = (window_size - width % window_size) % window_size
if pad_height > 0 or pad_width > 0:
# Noop in case pad_width == 0 and pad_height == 0.
hidden_state = nn.functional.pad(hidden_state, (0, 0, 0, pad_width, 0, pad_height))
patch_height, patch_width = height + pad_height, width + pad_width
padded_height, padded_width = height + pad_height, width + pad_width
hidden_state = hidden_state.view(
batch_size, patch_height // window_size, window_size, patch_width // window_size, window_size, num_channels
batch_size, padded_height // window_size, window_size, padded_width // window_size, window_size, num_channels
)
windows = hidden_state.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, num_channels)
return windows, (patch_height, patch_width)
return windows, (padded_height, padded_width)
def window_unpartition(windows, window_size, pad_height_width, height_width):
......@@ -426,22 +430,23 @@ def window_unpartition(windows, window_size, pad_height_width, height_width):
window_size (`int`):
Window size.
pad_height_width (`Tuple[int]`):
Padded height and width (patch_height, patch_width).
Padded height and width (padded_height, padded_width).
height_width (`Tuple[int]`):
Original height and width before padding.
Returns:
hidden_state: unpartitioned sequences with [batch_size, height, width, num_channels].
"""
patch_height, patch_width = pad_height_width
padded_height, padded_width = pad_height_width
height, width = height_width
batch_size = windows.shape[0] // (patch_height * patch_width // window_size // window_size)
batch_size = windows.shape[0] // (padded_height * padded_width // window_size // window_size)
hidden_state = windows.view(
batch_size, patch_height // window_size, patch_width // window_size, window_size, window_size, -1
batch_size, padded_height // window_size, padded_width // window_size, window_size, window_size, -1
)
hidden_state = hidden_state.permute(0, 1, 3, 2, 4, 5).contiguous().view(batch_size, patch_height, patch_width, -1)
hidden_state = hidden_state.permute(0, 1, 3, 2, 4, 5).contiguous()
hidden_state = hidden_state.view(batch_size, padded_height, padded_width, -1)
if patch_height > height or patch_width > width:
# We always have height <= padded_height and width <= padded_width
hidden_state = hidden_state[:, :height, :width, :].contiguous()
return hidden_state
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment