"git@developer.sourcefind.cn:chenpangpang/transformers.git" did not exist on "d80c9a349709b3db888b3976b660ef4ea2e29646"
Unverified Commit 9556054f authored by Rahul Vinod Vishwakarma, committed by GitHub

OWL-ViT box_predictor inefficiency issue (#29712)



* Calculating box_bias at the start once, then reusing it at inference

* Updating the compute_box_bias function for backwards compatibility

* Caching compute_box_bias function

* Bug fix

* Update owlv2 accordingly to ensure repo consistency

* Co-authored by: nvbinh15 <binh.pdc01@gmail.com>

* Fixup changes

* Made copied code consistent

* Co-authored by: nvbinh15 <binh.pdc01@gmail.com>

---------

Co-authored-by: Nguyen Van Binh <>
Co-authored-by: Nguyen Van Binh <binh.pdc01@gmail.com>
parent 0639034a
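In short, the patch hoists the grid-dependent box-bias computation out of the per-call box_predictor path: the bias is computed once in __init__ (and memoized with functools.lru_cache so other grid sizes are also only computed once), and inference just moves the cached tensor to the right device. A minimal sketch of the pattern, with a stand-in for the actual bias math (the class and method bodies below are illustrative, not the transformers code):

from functools import lru_cache

import torch


class BoxPredictor:
    """Illustrative stand-in for the pattern used in the patch."""

    def __init__(self, num_patches: int):
        # Compute the grid-dependent bias once at construction time.
        self.box_bias = self.compute_box_bias(num_patches)

    @lru_cache(maxsize=2)
    def compute_box_bias(self, num_patches: int) -> torch.Tensor:
        # Stand-in for the real bias math: any deterministic function of the
        # grid size illustrates the memoization, keyed on (self, num_patches).
        coords = torch.arange(1, num_patches + 1, dtype=torch.float32) / num_patches
        return torch.log(coords + 1e-4) - torch.log1p(-coords + 1e-4)

    def box_predictor(self, pred_boxes: torch.Tensor) -> torch.Tensor:
        # Inference reuses the cached bias; only a device move remains per call.
        return pred_boxes + self.box_bias.to(pred_boxes.device)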
src/transformers/models/owlv2/modeling_owlv2.py

@@ -16,9 +16,9 @@
 import warnings
 from dataclasses import dataclass
+from functools import lru_cache
 from typing import Any, Dict, Optional, Tuple, Union
 
-import numpy as np
 import torch
 import torch.utils.checkpoint
 from torch import Tensor, nn
@@ -1312,27 +1312,22 @@ class Owlv2ForObjectDetection(Owlv2PreTrainedModel):
         self.sigmoid = nn.Sigmoid()
         self.sqrt_num_patches = config.vision_config.image_size // config.vision_config.patch_size
+        self.box_bias = self.compute_box_bias(self.sqrt_num_patches)
 
+    @staticmethod
     # Copied from transformers.models.owlvit.modeling_owlvit.OwlViTForObjectDetection.normalize_grid_corner_coordinates
-    def normalize_grid_corner_coordinates(self, feature_map: torch.FloatTensor):
-        # Computes normalized xy corner coordinates from feature_map.
-        if not feature_map.ndim == 4:
-            raise ValueError("Expected input shape is [batch_size, num_patches, num_patches, hidden_dim]")
-
-        device = feature_map.device
-        num_patches = feature_map.shape[1]
-
-        # TODO: Remove numpy usage.
-        box_coordinates = np.stack(
-            np.meshgrid(np.arange(1, num_patches + 1), np.arange(1, num_patches + 1)), axis=-1
-        ).astype(np.float32)
-        box_coordinates /= np.array([num_patches, num_patches], np.float32)
+    def normalize_grid_corner_coordinates(num_patches: int) -> torch.Tensor:
+        # Create grid coordinates using torch
+        x_coordinates = torch.arange(1, num_patches + 1, dtype=torch.float32)
+        y_coordinates = torch.arange(1, num_patches + 1, dtype=torch.float32)
+        xx, yy = torch.meshgrid(x_coordinates, y_coordinates, indexing="xy")
 
+        # Stack the coordinates and divide by num_patches
+        box_coordinates = torch.stack((xx, yy), dim=-1)
+        box_coordinates /= num_patches
 
         # Flatten (h, w, 2) -> (h*w, 2)
-        box_coordinates = box_coordinates.reshape(
-            box_coordinates.shape[0] * box_coordinates.shape[1], box_coordinates.shape[2]
-        )
-        box_coordinates = torch.from_numpy(box_coordinates).to(device)
+        box_coordinates = box_coordinates.view(-1, 2)
 
         return box_coordinates
@@ -1350,17 +1345,20 @@ class Owlv2ForObjectDetection(Owlv2PreTrainedModel):
         objectness_logits = objectness_logits[..., 0]
         return objectness_logits
 
+    @lru_cache(maxsize=2)
     # Copied from transformers.models.owlvit.modeling_owlvit.OwlViTForObjectDetection.compute_box_bias
-    def compute_box_bias(self, feature_map: torch.FloatTensor) -> torch.FloatTensor:
+    def compute_box_bias(self, num_patches: int, feature_map: Optional[torch.FloatTensor] = None) -> torch.Tensor:
+        if feature_map is not None:
+            raise ValueError("feature_map has been deprecated as an input. Please pass in num_patches instead")
         # The box center is biased to its position on the feature grid
-        box_coordinates = self.normalize_grid_corner_coordinates(feature_map)
+        box_coordinates = self.normalize_grid_corner_coordinates(num_patches)
         box_coordinates = torch.clip(box_coordinates, 0.0, 1.0)
 
         # Unnormalize xy
         box_coord_bias = torch.log(box_coordinates + 1e-4) - torch.log1p(-box_coordinates + 1e-4)
 
         # The box size is biased to the patch size
-        box_size = torch.full_like(box_coord_bias, 1.0 / feature_map.shape[-2])
+        box_size = torch.full_like(box_coord_bias, 1.0 / num_patches)
         box_size_bias = torch.log(box_size + 1e-4) - torch.log1p(-box_size + 1e-4)
 
         # Compute box bias
@@ -1387,7 +1385,8 @@ class Owlv2ForObjectDetection(Owlv2PreTrainedModel):
         pred_boxes = self.box_head(image_feats)
 
         # Compute the location of each token on the grid and use it to compute a bias for the bbox prediction
-        pred_boxes += self.compute_box_bias(feature_map)
+        box_bias = self.box_bias.to(feature_map.device)
+        pred_boxes += box_bias
         pred_boxes = self.sigmoid(pred_boxes)
         return pred_boxes
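A note on the @lru_cache(maxsize=2) placement above: the cache keys on the positional arguments (self, num_patches), which works because ints are hashable, and maxsize=2 leaves room for a second grid size without unbounded growth. A hedged sanity check, assuming an already constructed Owlv2ForObjectDetection instance named model:

bias_a = model.compute_box_bias(16)  # first call with this grid size: cache miss
bias_b = model.compute_box_bias(16)  # same arguments again: cache hit
assert bias_a is bias_b  # lru_cache returns the identical cached tensor
print(type(model).compute_box_bias.cache_info())  # hit/miss counters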
src/transformers/models/owlvit/modeling_owlvit.py

@@ -16,9 +16,9 @@
 import warnings
 from dataclasses import dataclass
+from functools import lru_cache
 from typing import Any, Dict, Optional, Tuple, Union
 
-import numpy as np
 import torch
 import torch.utils.checkpoint
 from torch import Tensor, nn
@@ -1293,39 +1293,37 @@ class OwlViTForObjectDetection(OwlViTPreTrainedModel):
         self.sigmoid = nn.Sigmoid()
         self.sqrt_num_patches = config.vision_config.image_size // config.vision_config.patch_size
+        self.box_bias = self.compute_box_bias(self.sqrt_num_patches)
 
-    def normalize_grid_corner_coordinates(self, feature_map: torch.FloatTensor):
-        # Computes normalized xy corner coordinates from feature_map.
-        if not feature_map.ndim == 4:
-            raise ValueError("Expected input shape is [batch_size, num_patches, num_patches, hidden_dim]")
-
-        device = feature_map.device
-        num_patches = feature_map.shape[1]
-
-        # TODO: Remove numpy usage.
-        box_coordinates = np.stack(
-            np.meshgrid(np.arange(1, num_patches + 1), np.arange(1, num_patches + 1)), axis=-1
-        ).astype(np.float32)
-        box_coordinates /= np.array([num_patches, num_patches], np.float32)
+    @staticmethod
+    def normalize_grid_corner_coordinates(num_patches: int) -> torch.Tensor:
+        # Create grid coordinates using torch
+        x_coordinates = torch.arange(1, num_patches + 1, dtype=torch.float32)
+        y_coordinates = torch.arange(1, num_patches + 1, dtype=torch.float32)
+        xx, yy = torch.meshgrid(x_coordinates, y_coordinates, indexing="xy")
 
+        # Stack the coordinates and divide by num_patches
+        box_coordinates = torch.stack((xx, yy), dim=-1)
+        box_coordinates /= num_patches
 
         # Flatten (h, w, 2) -> (h*w, 2)
-        box_coordinates = box_coordinates.reshape(
-            box_coordinates.shape[0] * box_coordinates.shape[1], box_coordinates.shape[2]
-        )
-        box_coordinates = torch.from_numpy(box_coordinates).to(device)
+        box_coordinates = box_coordinates.view(-1, 2)
 
         return box_coordinates
 
-    def compute_box_bias(self, feature_map: torch.FloatTensor) -> torch.FloatTensor:
+    @lru_cache(maxsize=2)
+    def compute_box_bias(self, num_patches: int, feature_map: Optional[torch.FloatTensor] = None) -> torch.Tensor:
+        if feature_map is not None:
+            raise ValueError("feature_map has been deprecated as an input. Please pass in num_patches instead")
         # The box center is biased to its position on the feature grid
-        box_coordinates = self.normalize_grid_corner_coordinates(feature_map)
+        box_coordinates = self.normalize_grid_corner_coordinates(num_patches)
         box_coordinates = torch.clip(box_coordinates, 0.0, 1.0)
 
         # Unnormalize xy
         box_coord_bias = torch.log(box_coordinates + 1e-4) - torch.log1p(-box_coordinates + 1e-4)
 
         # The box size is biased to the patch size
-        box_size = torch.full_like(box_coord_bias, 1.0 / feature_map.shape[-2])
+        box_size = torch.full_like(box_coord_bias, 1.0 / num_patches)
         box_size_bias = torch.log(box_size + 1e-4) - torch.log1p(-box_size + 1e-4)
 
         # Compute box bias
@@ -1351,7 +1349,8 @@ class OwlViTForObjectDetection(OwlViTPreTrainedModel):
         pred_boxes = self.box_head(image_feats)
 
         # Compute the location of each token on the grid and use it to compute a bias for the bbox prediction
-        pred_boxes += self.compute_box_bias(feature_map)
+        box_bias = self.box_bias.to(feature_map.device)
+        pred_boxes += box_bias
         pred_boxes = self.sigmoid(pred_boxes)
         return pred_boxes
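For reference, a minimal end-to-end inference sketch against the public OWL-ViT API to confirm the precomputed-bias path still works; the checkpoint is the standard google/owlvit-base-patch32 and the image path is a placeholder:

import torch
from PIL import Image

from transformers import OwlViTForObjectDetection, OwlViTProcessor

processor = OwlViTProcessor.from_pretrained("google/owlvit-base-patch32")
model = OwlViTForObjectDetection.from_pretrained("google/owlvit-base-patch32")

image = Image.open("cat.jpg")  # placeholder image path
inputs = processor(text=[["a photo of a cat"]], images=image, return_tensors="pt")

with torch.no_grad():
    outputs = model(**inputs)  # the precomputed box bias is reused here

print(outputs.pred_boxes.shape)  # torch.Size([1, 576, 4]) for base-patch32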