Unverified Commit 847e5691 authored by Dennis Loevlie, committed by GitHub

Fix: Change tensors to integers for torch.dynamo and torch.compile compatibility (#23475)

* Fix: Change tensors to integers in torch.split() for torch.dynamo and torch.compile compatibility

* Applied the suggested fix to the utils/check_copies.py test

* Applied the suggested fix by changing the original function that gets copied
parent 389bdba6
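
For context, a minimal sketch of the pattern this commit fixes (illustrative shapes, not the actual transformers code). Iterating over a tensor of `(height, width)` spatial shapes yields 0-d tensors, and `torch.dynamo` can fail to trace tensor-valued split sizes; calling `.item()` turns them into plain Python ints that `torch.split` and `torch.compile` handle cleanly:

```python
import torch

# Hypothetical shapes for illustration: two feature levels of 8x8 and 4x4,
# flattened and concatenated along dim 1, as in deformable attention.
value_spatial_shapes = torch.tensor([[8, 8], [4, 4]])
value = torch.randn(2, 8 * 8 + 4 * 4, 4, 16)  # (batch, sum(H*W), heads, head_dim)

# Before the fix: `height * width` is a 0-d tensor, which torch.dynamo may
# fail to trace when it is used as a split size.
# value_list = value.split([height * width for height, width in value_spatial_shapes], dim=1)

# After the fix: .item() converts each 0-d tensor into a plain Python int.
value_list = value.split(
    [height.item() * width.item() for height, width in value_spatial_shapes], dim=1
)
print([chunk.shape for chunk in value_list])  # sizes 64 and 16 along dim 1
```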
@@ -550,7 +550,7 @@ def multi_scale_deformable_attention(
 ) -> Tensor:
     batch_size, _, num_heads, hidden_dim = value.shape
     _, num_queries, num_heads, num_levels, num_points, _ = sampling_locations.shape
-    value_list = value.split([height * width for height, width in value_spatial_shapes], dim=1)
+    value_list = value.split([height.item() * width.item() for height, width in value_spatial_shapes], dim=1)
     sampling_grids = 2 * sampling_locations - 1
     sampling_value_list = []
     for level_id, (height, width) in enumerate(value_spatial_shapes):
...
@@ -453,7 +453,7 @@ def multi_scale_deformable_attention(
 ) -> Tensor:
     batch_size, _, num_heads, hidden_dim = value.shape
     _, num_queries, num_heads, num_levels, num_points, _ = sampling_locations.shape
-    value_list = value.split([height * width for height, width in value_spatial_shapes], dim=1)
+    value_list = value.split([height.item() * width.item() for height, width in value_spatial_shapes], dim=1)
     sampling_grids = 2 * sampling_locations - 1
     sampling_value_list = []
     for level_id, (height, width) in enumerate(value_spatial_shapes):
...
@@ -810,7 +810,7 @@ def multi_scale_deformable_attention(
 ) -> Tensor:
     batch_size, _, num_heads, hidden_dim = value.shape
     _, num_queries, num_heads, num_levels, num_points, _ = sampling_locations.shape
-    value_list = value.split([height * width for height, width in value_spatial_shapes], dim=1)
+    value_list = value.split([height.item() * width.item() for height, width in value_spatial_shapes], dim=1)
     sampling_grids = 2 * sampling_locations - 1
     sampling_value_list = []
     for level_id, (height, width) in enumerate(value_spatial_shapes):
@@ -1340,7 +1340,7 @@ class Mask2FormerPixelDecoder(nn.Module):
             else:
                 split_sizes[i] = last_hidden_state.shape[1] - level_start_index[i]
-        encoder_output = torch.split(last_hidden_state, split_sizes, dim=1)
+        encoder_output = torch.split(last_hidden_state, [size.item() for size in split_sizes], dim=1)
         # Compute final features
         outputs = [
...
@@ -66,7 +66,7 @@ def multi_scale_deformable_attention(
 ) -> Tensor:
     batch_size, _, num_heads, hidden_dim = value.shape
     _, num_queries, num_heads, num_levels, num_points, _ = sampling_locations.shape
-    value_list = value.split([height * width for height, width in value_spatial_shapes], dim=1)
+    value_list = value.split([height.item() * width.item() for height, width in value_spatial_shapes], dim=1)
     sampling_grids = 2 * sampling_locations - 1
     sampling_value_list = []
     for level_id, (height, width) in enumerate(value_spatial_shapes):
...
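
The Mask2FormerPixelDecoder hunk above applies the same idea where the split sizes are derived from `level_start_index`. A sketch with assumed toy shapes:

```python
import torch

# Assumed toy shapes: 80 tokens along dim 1 covering two feature levels
# that start at offsets 0 and 64.
last_hidden_state = torch.randn(2, 80, 256)  # (batch, sum(H*W), channels)
level_start_index = torch.tensor([0, 64])
num_levels = level_start_index.shape[0]

split_sizes = [None] * num_levels
for i in range(num_levels):
    if i < num_levels - 1:
        # Difference of two 0-d tensors is itself a 0-d tensor, not an int.
        split_sizes[i] = level_start_index[i + 1] - level_start_index[i]
    else:
        split_sizes[i] = last_hidden_state.shape[1] - level_start_index[i]

# .item() converts each 0-d tensor size to a Python int so torch.split
# traces cleanly under torch.compile.
encoder_output = torch.split(last_hidden_state, [size.item() for size in split_sizes], dim=1)
print([chunk.shape for chunk in encoder_output])  # sizes 64 and 16 along dim 1
```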