"vscode:/vscode.git/clone" did not exist on "c3d6f33918bd67fd624865c3431dc2977d15450a"
Unverified commit 847e5691, authored by Dennis Loevlie, committed by GitHub

Fix: Change tensors to integers for torch.dynamo and torch.compile compatibility (#23475)

* Fix: Change tensors to integers in torch.split() for torch.dynamo and torch.compile compatibility

* Applied the suggested fix to the utils/check_copies.py test

* Applied the suggested fix by changing the original function that gets copied
parent 389bdba6
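
For context, here is a minimal sketch (not part of this commit) of the incompatibility being fixed. Iterating over value_spatial_shapes yields 0-dim tensors, so each height * width split size is itself a tensor; eager mode accepts that via __index__, but torch.dynamo tracing does not handle tensor-valued split sizes, whereas .item() hands torch.split plain Python ints. All names and shapes below are hypothetical.

import torch

# Hypothetical inputs: two feature levels of 8x8 and 4x4, flattened along dim 1.
value_spatial_shapes = torch.tensor([[8, 8], [4, 4]])
value = torch.randn(1, 8 * 8 + 4 * 4, 2, 16)

def split_per_level(value, value_spatial_shapes):
    # .item() converts each 0-dim tensor to a Python int, so the split
    # sizes are concrete integers rather than traced tensor values.
    sizes = [height.item() * width.item() for height, width in value_spatial_shapes]
    return value.split(sizes, dim=1)

chunks = torch.compile(split_per_level)(value, value_spatial_shapes)
print([tuple(chunk.shape) for chunk in chunks])  # [(1, 64, 2, 16), (1, 16, 2, 16)]
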
@@ -550,7 +550,7 @@ def multi_scale_deformable_attention(
 ) -> Tensor:
     batch_size, _, num_heads, hidden_dim = value.shape
     _, num_queries, num_heads, num_levels, num_points, _ = sampling_locations.shape
-    value_list = value.split([height * width for height, width in value_spatial_shapes], dim=1)
+    value_list = value.split([height.item() * width.item() for height, width in value_spatial_shapes], dim=1)
     sampling_grids = 2 * sampling_locations - 1
     sampling_value_list = []
     for level_id, (height, width) in enumerate(value_spatial_shapes):
@@ -453,7 +453,7 @@ def multi_scale_deformable_attention(
 ) -> Tensor:
     batch_size, _, num_heads, hidden_dim = value.shape
     _, num_queries, num_heads, num_levels, num_points, _ = sampling_locations.shape
-    value_list = value.split([height * width for height, width in value_spatial_shapes], dim=1)
+    value_list = value.split([height.item() * width.item() for height, width in value_spatial_shapes], dim=1)
     sampling_grids = 2 * sampling_locations - 1
     sampling_value_list = []
     for level_id, (height, width) in enumerate(value_spatial_shapes):
@@ -810,7 +810,7 @@ def multi_scale_deformable_attention(
 ) -> Tensor:
     batch_size, _, num_heads, hidden_dim = value.shape
     _, num_queries, num_heads, num_levels, num_points, _ = sampling_locations.shape
-    value_list = value.split([height * width for height, width in value_spatial_shapes], dim=1)
+    value_list = value.split([height.item() * width.item() for height, width in value_spatial_shapes], dim=1)
     sampling_grids = 2 * sampling_locations - 1
     sampling_value_list = []
     for level_id, (height, width) in enumerate(value_spatial_shapes):
@@ -1340,7 +1340,7 @@ class Mask2FormerPixelDecoder(nn.Module):
             else:
                 split_sizes[i] = last_hidden_state.shape[1] - level_start_index[i]
 
-        encoder_output = torch.split(last_hidden_state, split_sizes, dim=1)
+        encoder_output = torch.split(last_hidden_state, [size.item() for size in split_sizes], dim=1)
 
         # Compute final features
         outputs = [
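The same pattern, sketched for the Mask2FormerPixelDecoder change above: split_sizes is built by tensor arithmetic on level_start_index, so each entry is a 0-dim tensor until .item() converts it. The sizes and names here are made up for illustration.

import torch

# Hypothetical two-level layout: level 0 starts at index 0, level 1 at 64, 80 tokens total.
last_hidden_state = torch.randn(1, 80, 256)
level_start_index = torch.tensor([0, 64])
num_levels = 2

split_sizes = [None] * num_levels
for i in range(num_levels):
    if i < num_levels - 1:
        # Difference of two tensor elements: a 0-dim tensor, not an int.
        split_sizes[i] = level_start_index[i + 1] - level_start_index[i]
    else:
        split_sizes[i] = last_hidden_state.shape[1] - level_start_index[i]

# .item() turns each 0-dim tensor into a Python int before torch.split sees it.
encoder_output = torch.split(last_hidden_state, [size.item() for size in split_sizes], dim=1)
print([tuple(chunk.shape) for chunk in encoder_output])  # [(1, 64, 256), (1, 16, 256)]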
@@ -66,7 +66,7 @@ def multi_scale_deformable_attention(
 ) -> Tensor:
     batch_size, _, num_heads, hidden_dim = value.shape
     _, num_queries, num_heads, num_levels, num_points, _ = sampling_locations.shape
-    value_list = value.split([height * width for height, width in value_spatial_shapes], dim=1)
+    value_list = value.split([height.item() * width.item() for height, width in value_spatial_shapes], dim=1)
     sampling_grids = 2 * sampling_locations - 1
     sampling_value_list = []
     for level_id, (height, width) in enumerate(value_spatial_shapes):