Unverified Commit 9556054f authored by Rahul Vinod Vishwakarma, committed by GitHub

OWL-ViT box_predictor inefficiency issue (#29712)



* Calculating box_bias at the start once, then reusing it at inference

* Updating the compute_box_bias function for backwards compatibility

* Caching compute_box_bias function

* Bug fix

* Update owlv2 accordingly to ensure repo consistency

* Co-authored by: nvbinh15 <binh.pdc01@gmail.com>

* Fixup changes

* Made copied code consistent

* Co-authored by: nvbinh15 <binh.pdc01@gmail.com>

---------

Co-authored-by: Nguyen Van Binh <>
Co-authored-by: Nguyen Van Binh <binh.pdc01@gmail.com>
parent 0639034a
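In short, the fix is a memoization pattern: the box bias depends only on the number of patches per side, so it is computed once in `__init__` and cached with `functools.lru_cache`, instead of being rebuilt from the `feature_map` on every forward pass. A minimal sketch of the pattern, using illustrative names rather than the real `transformers` classes:

```python
from functools import lru_cache

import torch


class BoxPredictorSketch:
    """Illustrative sketch of the caching pattern in this commit; not the real API."""

    def __init__(self, image_size: int, patch_size: int):
        self.sqrt_num_patches = image_size // patch_size
        # Compute the grid-dependent bias once at construction...
        self.box_bias = self.compute_box_bias(self.sqrt_num_patches)

    @lru_cache(maxsize=2)
    def compute_box_bias(self, num_patches: int) -> torch.Tensor:
        # ...and memoize it on num_patches so repeated calls are free.
        coords = torch.arange(1, num_patches + 1, dtype=torch.float32) / num_patches
        xx, yy = torch.meshgrid(coords, coords, indexing="xy")
        box_coordinates = torch.clip(torch.stack((xx, yy), dim=-1).view(-1, 2), 0.0, 1.0)
        box_coord_bias = torch.log(box_coordinates + 1e-4) - torch.log1p(-box_coordinates + 1e-4)
        box_size = torch.full_like(box_coord_bias, 1.0 / num_patches)
        box_size_bias = torch.log(box_size + 1e-4) - torch.log1p(-box_size + 1e-4)
        return torch.cat([box_coord_bias, box_size_bias], dim=-1)

    def box_predictor(self, pred_boxes: torch.Tensor, device: torch.device) -> torch.Tensor:
        # Inference reuses the precomputed bias instead of rebuilding it per call.
        return torch.sigmoid(pred_boxes + self.box_bias.to(device))
```

With `maxsize=2`, repeated calls for the same grid size always hit the cache; the actual diff follows.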
src/transformers/models/owlv2/modeling_owlv2.py
@@ -16,9 +16,9 @@
 import warnings
 from dataclasses import dataclass
+from functools import lru_cache
 from typing import Any, Dict, Optional, Tuple, Union
 
-import numpy as np
 import torch
 import torch.utils.checkpoint
 from torch import Tensor, nn
@@ -1312,27 +1312,22 @@ class Owlv2ForObjectDetection(Owlv2PreTrainedModel):
         self.sigmoid = nn.Sigmoid()
         self.sqrt_num_patches = config.vision_config.image_size // config.vision_config.patch_size
+        self.box_bias = self.compute_box_bias(self.sqrt_num_patches)
 
+    @staticmethod
     # Copied from transformers.models.owlvit.modeling_owlvit.OwlViTForObjectDetection.normalize_grid_corner_coordinates
-    def normalize_grid_corner_coordinates(self, feature_map: torch.FloatTensor):
-        # Computes normalized xy corner coordinates from feature_map.
-        if not feature_map.ndim == 4:
-            raise ValueError("Expected input shape is [batch_size, num_patches, num_patches, hidden_dim]")
-
-        device = feature_map.device
-        num_patches = feature_map.shape[1]
-
-        # TODO: Remove numpy usage.
-        box_coordinates = np.stack(
-            np.meshgrid(np.arange(1, num_patches + 1), np.arange(1, num_patches + 1)), axis=-1
-        ).astype(np.float32)
-        box_coordinates /= np.array([num_patches, num_patches], np.float32)
+    def normalize_grid_corner_coordinates(num_patches: int) -> torch.Tensor:
+        # Create grid coordinates using torch
+        x_coordinates = torch.arange(1, num_patches + 1, dtype=torch.float32)
+        y_coordinates = torch.arange(1, num_patches + 1, dtype=torch.float32)
+        xx, yy = torch.meshgrid(x_coordinates, y_coordinates, indexing="xy")
+
+        # Stack the coordinates and divide by num_patches
+        box_coordinates = torch.stack((xx, yy), dim=-1)
+        box_coordinates /= num_patches
 
         # Flatten (h, w, 2) -> (h*w, 2)
-        box_coordinates = box_coordinates.reshape(
-            box_coordinates.shape[0] * box_coordinates.shape[1], box_coordinates.shape[2]
-        )
-        box_coordinates = torch.from_numpy(box_coordinates).to(device)
+        box_coordinates = box_coordinates.view(-1, 2)
 
         return box_coordinates
@@ -1350,17 +1345,20 @@ class Owlv2ForObjectDetection(Owlv2PreTrainedModel):
         objectness_logits = objectness_logits[..., 0]
 
         return objectness_logits
 
+    @lru_cache(maxsize=2)
     # Copied from transformers.models.owlvit.modeling_owlvit.OwlViTForObjectDetection.compute_box_bias
-    def compute_box_bias(self, feature_map: torch.FloatTensor) -> torch.FloatTensor:
+    def compute_box_bias(self, num_patches: int, feature_map: Optional[torch.FloatTensor] = None) -> torch.Tensor:
+        if feature_map is not None:
+            raise ValueError("feature_map has been deprecated as an input. Please pass in num_patches instead")
+
         # The box center is biased to its position on the feature grid
-        box_coordinates = self.normalize_grid_corner_coordinates(feature_map)
+        box_coordinates = self.normalize_grid_corner_coordinates(num_patches)
         box_coordinates = torch.clip(box_coordinates, 0.0, 1.0)
 
         # Unnormalize xy
         box_coord_bias = torch.log(box_coordinates + 1e-4) - torch.log1p(-box_coordinates + 1e-4)
 
         # The box size is biased to the patch size
-        box_size = torch.full_like(box_coord_bias, 1.0 / feature_map.shape[-2])
+        box_size = torch.full_like(box_coord_bias, 1.0 / num_patches)
         box_size_bias = torch.log(box_size + 1e-4) - torch.log1p(-box_size + 1e-4)
 
         # Compute box bias
@@ -1387,7 +1385,8 @@ class Owlv2ForObjectDetection(Owlv2PreTrainedModel):
         pred_boxes = self.box_head(image_feats)
 
         # Compute the location of each token on the grid and use it to compute a bias for the bbox prediction
-        pred_boxes += self.compute_box_bias(feature_map)
+        box_bias = self.box_bias.to(feature_map.device)
+        pred_boxes += box_bias
         pred_boxes = self.sigmoid(pred_boxes)
         return pred_boxes
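Backwards compatibility note: `compute_box_bias` keeps `feature_map` as a deprecated keyword argument that raises immediately, so stale call sites fail loudly instead of silently recomputing. Also worth knowing: `lru_cache` on an instance method includes `self` in the cache key, so the cache holds a strong reference to the model; `maxsize=2` keeps that bounded. A hedged sketch of new versus deprecated calls (the checkpoint name is illustrative and not part of this commit):

```python
import torch
from transformers import Owlv2ForObjectDetection

# Assumes the google/owlv2-base-patch16-ensemble checkpoint; illustrative only.
model = Owlv2ForObjectDetection.from_pretrained("google/owlv2-base-patch16-ensemble")

# New style: the cached computation is keyed on the grid size.
box_bias = model.compute_box_bias(model.sqrt_num_patches)

# Deprecated style: passing a feature_map now raises a ValueError.
dummy_feature_map = torch.zeros(1, model.sqrt_num_patches, model.sqrt_num_patches, 768)
try:
    model.compute_box_bias(model.sqrt_num_patches, feature_map=dummy_feature_map)
except ValueError as err:
    print(err)  # feature_map has been deprecated as an input...
```

The same change is mirrored in the OwlViT source file below, which the owlv2 code is copied from.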
src/transformers/models/owlvit/modeling_owlvit.py
@@ -16,9 +16,9 @@
 import warnings
 from dataclasses import dataclass
+from functools import lru_cache
 from typing import Any, Dict, Optional, Tuple, Union
 
-import numpy as np
 import torch
 import torch.utils.checkpoint
 from torch import Tensor, nn
@@ -1293,39 +1293,37 @@ class OwlViTForObjectDetection(OwlViTPreTrainedModel):
         self.sigmoid = nn.Sigmoid()
         self.sqrt_num_patches = config.vision_config.image_size // config.vision_config.patch_size
+        self.box_bias = self.compute_box_bias(self.sqrt_num_patches)
 
-    def normalize_grid_corner_coordinates(self, feature_map: torch.FloatTensor):
-        # Computes normalized xy corner coordinates from feature_map.
-        if not feature_map.ndim == 4:
-            raise ValueError("Expected input shape is [batch_size, num_patches, num_patches, hidden_dim]")
-
-        device = feature_map.device
-        num_patches = feature_map.shape[1]
-
-        # TODO: Remove numpy usage.
-        box_coordinates = np.stack(
-            np.meshgrid(np.arange(1, num_patches + 1), np.arange(1, num_patches + 1)), axis=-1
-        ).astype(np.float32)
-        box_coordinates /= np.array([num_patches, num_patches], np.float32)
+    @staticmethod
+    def normalize_grid_corner_coordinates(num_patches: int) -> torch.Tensor:
+        # Create grid coordinates using torch
+        x_coordinates = torch.arange(1, num_patches + 1, dtype=torch.float32)
+        y_coordinates = torch.arange(1, num_patches + 1, dtype=torch.float32)
+        xx, yy = torch.meshgrid(x_coordinates, y_coordinates, indexing="xy")
+
+        # Stack the coordinates and divide by num_patches
+        box_coordinates = torch.stack((xx, yy), dim=-1)
+        box_coordinates /= num_patches
 
         # Flatten (h, w, 2) -> (h*w, 2)
-        box_coordinates = box_coordinates.reshape(
-            box_coordinates.shape[0] * box_coordinates.shape[1], box_coordinates.shape[2]
-        )
-        box_coordinates = torch.from_numpy(box_coordinates).to(device)
+        box_coordinates = box_coordinates.view(-1, 2)
 
         return box_coordinates
 
-    def compute_box_bias(self, feature_map: torch.FloatTensor) -> torch.FloatTensor:
+    @lru_cache(maxsize=2)
+    def compute_box_bias(self, num_patches: int, feature_map: Optional[torch.FloatTensor] = None) -> torch.Tensor:
+        if feature_map is not None:
+            raise ValueError("feature_map has been deprecated as an input. Please pass in num_patches instead")
+
         # The box center is biased to its position on the feature grid
-        box_coordinates = self.normalize_grid_corner_coordinates(feature_map)
+        box_coordinates = self.normalize_grid_corner_coordinates(num_patches)
         box_coordinates = torch.clip(box_coordinates, 0.0, 1.0)
 
         # Unnormalize xy
         box_coord_bias = torch.log(box_coordinates + 1e-4) - torch.log1p(-box_coordinates + 1e-4)
 
         # The box size is biased to the patch size
-        box_size = torch.full_like(box_coord_bias, 1.0 / feature_map.shape[-2])
+        box_size = torch.full_like(box_coord_bias, 1.0 / num_patches)
         box_size_bias = torch.log(box_size + 1e-4) - torch.log1p(-box_size + 1e-4)
 
         # Compute box bias
@@ -1351,7 +1349,8 @@ class OwlViTForObjectDetection(OwlViTPreTrainedModel):
         pred_boxes = self.box_head(image_feats)
 
         # Compute the location of each token on the grid and use it to compute a bias for the bbox prediction
-        pred_boxes += self.compute_box_bias(feature_map)
+        box_bias = self.box_bias.to(feature_map.device)
+        pred_boxes += box_bias
         pred_boxes = self.sigmoid(pred_boxes)
         return pred_boxes
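For context, a minimal inference sketch against the updated OwlViT model; the checkpoint name and image path are illustrative and not part of this commit:

```python
import torch
from PIL import Image
from transformers import OwlViTProcessor, OwlViTForObjectDetection

processor = OwlViTProcessor.from_pretrained("google/owlvit-base-patch32")
model = OwlViTForObjectDetection.from_pretrained("google/owlvit-base-patch32")
# box_bias is now built once here, at construction time.

image = Image.open("cat.png")  # illustrative local image
inputs = processor(text=[["a photo of a cat"]], images=image, return_tensors="pt")
with torch.no_grad():
    outputs = model(**inputs)  # box_predictor reuses model.box_bias on every call
```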