Unverified Commit ef38e2a7 authored by fxmarty's avatar fxmarty Committed by GitHub
Browse files

Make vitdet jit trace complient (#30065)

* remove controlflows

* style

* rename patch_ to padded_ following review comment

* style
parent a71def02
...@@ -94,11 +94,12 @@ class VitDetEmbeddings(nn.Module): ...@@ -94,11 +94,12 @@ class VitDetEmbeddings(nn.Module):
if has_cls_token: if has_cls_token:
abs_pos_embeddings = abs_pos_embeddings[:, 1:] abs_pos_embeddings = abs_pos_embeddings[:, 1:]
num_position = abs_pos_embeddings.shape[1] num_position = abs_pos_embeddings.shape[1]
size = int(math.sqrt(num_position)) size = int(math.sqrt(num_position)) # This is a constant and can be recorded as such in the ONNX export.
if size * size != num_position: if size * size != num_position:
raise ValueError("Absolute position embeddings must be a square number.") raise ValueError("Absolute position embeddings must be a square number.")
if size != height or size != width: if torch.jit.is_tracing() or (size != height or size != width):
# nn.functional.interpolate is a noop in case size == height and size == width - we need to always capture this path with jit.trace.
new_abs_pos_embeddings = nn.functional.interpolate( new_abs_pos_embeddings = nn.functional.interpolate(
abs_pos_embeddings.reshape(1, size, size, -1).permute(0, 3, 1, 2), abs_pos_embeddings.reshape(1, size, size, -1).permute(0, 3, 1, 2),
size=(height, width), size=(height, width),
...@@ -132,6 +133,7 @@ class VitDetEmbeddings(nn.Module): ...@@ -132,6 +133,7 @@ class VitDetEmbeddings(nn.Module):
return embeddings return embeddings
@torch.jit.script_if_tracing # nn.functional.interpolate's `size` needs to be dynamic.
def get_rel_pos(q_size, k_size, rel_pos): def get_rel_pos(q_size, k_size, rel_pos):
""" """
Get relative positional embeddings according to the relative positions of query and key sizes. Get relative positional embeddings according to the relative positions of query and key sizes.
...@@ -399,21 +401,23 @@ def window_partition(hidden_state, window_size): ...@@ -399,21 +401,23 @@ def window_partition(hidden_state, window_size):
Returns: Returns:
`tuple(torch.FloatTensor)` comprising various elements: `tuple(torch.FloatTensor)` comprising various elements:
- windows: windows after partition with [batch_size * num_windows, window_size, window_size, num_channels]. - windows: windows after partition with [batch_size * num_windows, window_size, window_size, num_channels].
- (patch_height, patch_width): padded height and width before partition - (padded_height, padded_width): padded height and width before partition
""" """
batch_size, height, width, num_channels = hidden_state.shape batch_size, height, width, num_channels = hidden_state.shape
pad_height = (window_size - height % window_size) % window_size pad_height = (window_size - height % window_size) % window_size
pad_width = (window_size - width % window_size) % window_size pad_width = (window_size - width % window_size) % window_size
if pad_height > 0 or pad_width > 0:
hidden_state = nn.functional.pad(hidden_state, (0, 0, 0, pad_width, 0, pad_height)) # Noop in case pad_width == 0 and pad_height == 0.
patch_height, patch_width = height + pad_height, width + pad_width hidden_state = nn.functional.pad(hidden_state, (0, 0, 0, pad_width, 0, pad_height))
padded_height, padded_width = height + pad_height, width + pad_width
hidden_state = hidden_state.view( hidden_state = hidden_state.view(
batch_size, patch_height // window_size, window_size, patch_width // window_size, window_size, num_channels batch_size, padded_height // window_size, window_size, padded_width // window_size, window_size, num_channels
) )
windows = hidden_state.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, num_channels) windows = hidden_state.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, num_channels)
return windows, (patch_height, patch_width) return windows, (padded_height, padded_width)
def window_unpartition(windows, window_size, pad_height_width, height_width): def window_unpartition(windows, window_size, pad_height_width, height_width):
...@@ -426,23 +430,24 @@ def window_unpartition(windows, window_size, pad_height_width, height_width): ...@@ -426,23 +430,24 @@ def window_unpartition(windows, window_size, pad_height_width, height_width):
window_size (`int`): window_size (`int`):
Window size. Window size.
pad_height_width (`Tuple[int]`): pad_height_width (`Tuple[int]`):
Padded height and width (patch_height, patch_width). Padded height and width (padded_height, padded_width).
height_width (`Tuple[int]`): height_width (`Tuple[int]`):
Original height and width before padding. Original height and width before padding.
Returns: Returns:
hidden_state: unpartitioned sequences with [batch_size, height, width, num_channels]. hidden_state: unpartitioned sequences with [batch_size, height, width, num_channels].
""" """
patch_height, patch_width = pad_height_width padded_height, padded_width = pad_height_width
height, width = height_width height, width = height_width
batch_size = windows.shape[0] // (patch_height * patch_width // window_size // window_size) batch_size = windows.shape[0] // (padded_height * padded_width // window_size // window_size)
hidden_state = windows.view( hidden_state = windows.view(
batch_size, patch_height // window_size, patch_width // window_size, window_size, window_size, -1 batch_size, padded_height // window_size, padded_width // window_size, window_size, window_size, -1
) )
hidden_state = hidden_state.permute(0, 1, 3, 2, 4, 5).contiguous().view(batch_size, patch_height, patch_width, -1) hidden_state = hidden_state.permute(0, 1, 3, 2, 4, 5).contiguous()
hidden_state = hidden_state.view(batch_size, padded_height, padded_width, -1)
if patch_height > height or patch_width > width: # We always have height <= padded_height and width <= padded_width
hidden_state = hidden_state[:, :height, :width, :].contiguous() hidden_state = hidden_state[:, :height, :width, :].contiguous()
return hidden_state return hidden_state
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment