"docs/source/vscode:/vscode.git/clone" did not exist on "ba2a5f13f777e828cbabbee213dcc841dfef3d05"
Unverified commit 5d81a568, authored by Alara Dirik, committed by GitHub

Owlvit memory leak fix (#18734)

* fix memory leak
* fix typos
* use singular last hidden state variable names
* eliminate double call to self.owlvit to return last hidden states
* eliminate 2nd call to self.vision_model in OwlViTModel
parent 80367cd1
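For context, here is a minimal usage sketch of the code path this commit optimizes. The checkpoint name, image URL, and query texts are illustrative assumptions and are not part of the commit; the point is that a single `OwlViTForObjectDetection` forward pass previously ran the vision tower twice through the duplicated `self.owlvit` / `self.vision_model` calls.

```python
# Minimal sketch of the affected forward pass; checkpoint, image URL and queries are
# illustrative assumptions, not taken from this commit.
import requests
import torch
from PIL import Image
from transformers import OwlViTProcessor, OwlViTForObjectDetection

processor = OwlViTProcessor.from_pretrained("google/owlvit-base-patch32")
model = OwlViTForObjectDetection.from_pretrained("google/owlvit-base-patch32")

image = Image.open(requests.get("http://images.cocodataset.org/val2017/000000039769.jpg", stream=True).raw)
texts = [["a photo of a cat", "a photo of a remote"]]
inputs = processor(text=texts, images=image, return_tensors="pt")

with torch.no_grad():
    outputs = model(**inputs)  # after this fix, the image and text towers each run once

print(outputs.logits.shape)      # (batch_size, num_patches**2, max_text_queries)
print(outputs.pred_boxes.shape)  # (batch_size, num_patches**2, 4)
```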
@@ -140,9 +140,9 @@ class OwlViTObjectDetectionOutput(ModelOutput):
         class_embeds (`torch.FloatTensor` of shape `(batch_size, num_patches, hidden_size)`):
             Class embeddings of all image patches. OWL-ViT represents images as a set of image patches where the total
             number of patches is (image_size / patch_size)**2.
-        text_model_last_hidden_states (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`)):
+        text_model_last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`)):
             Last hidden states extracted from the [`OwlViTTextModel`].
-        vision_model_last_hidden_states (`torch.FloatTensor` of shape `(batch_size, num_patches + 1, hidden_size)`)):
+        vision_model_last_hidden_state (`torch.FloatTensor` of shape `(batch_size, num_patches + 1, hidden_size)`)):
             Last hidden states extracted from the [`OwlViTVisionModel`]. OWL-ViT represents images as a set of image
             patches where the total number of patches is (image_size / patch_size)**2.
     """
@@ -154,8 +154,8 @@ class OwlViTObjectDetectionOutput(ModelOutput):
     text_embeds: torch.FloatTensor = None
     image_embeds: torch.FloatTensor = None
     class_embeds: torch.FloatTensor = None
-    text_model_last_hidden_states: Optional[torch.FloatTensor] = None
-    vision_model_last_hidden_states: Optional[torch.FloatTensor] = None
+    text_model_last_hidden_state: Optional[torch.FloatTensor] = None
+    vision_model_last_hidden_state: Optional[torch.FloatTensor] = None


 class OwlViTVisionEmbeddings(nn.Module):
@@ -516,6 +516,9 @@ OWLVIT_INPUTS_DOCSTRING = r"""
         output_hidden_states (`bool`, *optional*):
             Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
             more detail.
+        return_base_image_embeds (`bool`, *optional*):
+            Whether or not to return unprojected image embeddings. Set to `True` when `OwlViTModel` is called within
+            `OwlViTForObjectDetection`.
         return_dict (`bool`, *optional*):
             Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
 """
@@ -1013,6 +1016,7 @@ class OwlViTModel(OwlViTPreTrainedModel):
         return_loss: Optional[bool] = None,
         output_attentions: Optional[bool] = None,
         output_hidden_states: Optional[bool] = None,
+        return_base_image_embeds: Optional[bool] = None,
         return_dict: Optional[bool] = None,
     ) -> Union[Tuple, OwlViTOutput]:
         r"""
@@ -1040,6 +1044,9 @@ class OwlViTModel(OwlViTPreTrainedModel):
         )
         return_dict = return_dict if return_dict is not None else self.config.use_return_dict
 
+        # Whether to return unprojected image features
+        return_base_image_embeds = return_base_image_embeds if return_base_image_embeds is not None else False
+
         vision_outputs = self.vision_model(
             pixel_values=pixel_values,
             output_attentions=output_attentions,
@@ -1075,6 +1082,10 @@ class OwlViTModel(OwlViTPreTrainedModel):
         if return_loss:
             loss = owlvit_loss(logits_per_text)
 
+        if return_base_image_embeds:
+            last_hidden_state = vision_outputs[0]
+            image_embeds = self.vision_model.post_layernorm(last_hidden_state)
+
         if not return_dict:
             output = (logits_per_image, logits_per_text, text_embeds, image_embeds, text_outputs, vision_outputs)
             return ((loss,) + output) if loss is not None else output
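The added branch keeps the full per-patch hidden states, re-normalized with the vision tower's final layernorm, instead of the pooled and projected image features. A toy shape sketch of the difference; the sizes are illustrative assumptions for a patch-32, 768px configuration and not taken from the commit:

```python
import torch
import torch.nn as nn

# Toy stand-in for vision_outputs[0]; sizes are illustrative assumptions
# (24 x 24 = 576 patches plus the class token, hidden size 768).
batch_size, num_patches, hidden_size = 1, 576, 768
last_hidden_state = torch.randn(batch_size, num_patches + 1, hidden_size)

post_layernorm = nn.LayerNorm(hidden_size)
base_image_embeds = post_layernorm(last_hidden_state)

print(base_image_embeds.shape)  # torch.Size([1, 577, 768]) -- per-patch features, not pooled
```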
@@ -1170,15 +1181,15 @@ class OwlViTForObjectDetection(OwlViTPreTrainedModel):
     def normalize_grid_corner_coordinates(self, feature_map: torch.FloatTensor):
         # Computes normalized xy corner coordinates from feature_map.
         if not feature_map.ndim == 4:
-            raise ValueError("Expected input shape is [batch_size, num_channels, height, width]")
+            raise ValueError("Expected input shape is [batch_size, num_patches, num_patches, hidden_dim]")
 
         device = feature_map.device
-        height, width = feature_map.shape[1:3]
+        num_patches = feature_map.shape[1]
 
-        box_coordinates = np.stack(np.meshgrid(np.arange(1, width + 1), np.arange(1, height + 1)), axis=-1).astype(
-            np.float32
-        )
-        box_coordinates /= np.array([width, height], np.float32)
+        box_coordinates = np.stack(
+            np.meshgrid(np.arange(1, num_patches + 1), np.arange(1, num_patches + 1)), axis=-1
+        ).astype(np.float32)
+        box_coordinates /= np.array([num_patches, num_patches], np.float32)
 
         # Flatten (h, w, 2) -> (h*w, 2)
         box_coordinates = box_coordinates.reshape(
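As a standalone check of the rewritten grid logic (not library code; `num_patches = 3` is an arbitrary toy size), the same NumPy calls produce normalized (x, y) corner coordinates in (0, 1]:

```python
import numpy as np

# Toy reproduction of the rewritten coordinate grid; num_patches is an arbitrary example.
num_patches = 3
box_coordinates = np.stack(
    np.meshgrid(np.arange(1, num_patches + 1), np.arange(1, num_patches + 1)), axis=-1
).astype(np.float32)
box_coordinates /= np.array([num_patches, num_patches], np.float32)

# Flatten (h, w, 2) -> (h*w, 2), as the model does before predicting boxes
print(box_coordinates.reshape(-1, 2))
# [[0.333 0.333] [0.667 0.333] [1. 0.333] [0.333 0.667] ... [1. 1.]]
```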
@@ -1232,7 +1243,7 @@ class OwlViTForObjectDetection(OwlViTPreTrainedModel):
         image_feats: torch.FloatTensor,
         query_embeds: torch.FloatTensor,
         query_mask: torch.Tensor,
-    ) -> Tuple[torch.FloatTensor, torch.FloatTensor]:
+    ) -> Tuple[torch.FloatTensor]:
         """
         Args:
             image_feats:
@@ -1252,18 +1263,21 @@ class OwlViTForObjectDetection(OwlViTPreTrainedModel):
         pixel_values: torch.FloatTensor,
         attention_mask: torch.Tensor,
         output_attentions: Optional[bool] = None,
-    ) -> torch.FloatTensor:
-        # Encode text
-        text_embeds = self.owlvit.get_text_features(
-            input_ids=input_ids, attention_mask=attention_mask, output_attentions=output_attentions
-        )
-
-        # Encode image
-        image_embeds = self.owlvit.get_image_features(
-            pixel_values, return_projected=False, output_attentions=output_attentions
+        output_hidden_states: Optional[bool] = None,
+    ) -> Tuple[torch.FloatTensor]:
+        # Encode text and image
+        outputs = self.owlvit(
+            pixel_values=pixel_values,
+            input_ids=input_ids,
+            attention_mask=attention_mask,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            return_base_image_embeds=True,
         )
 
         # Resize class token
+        image_embeds = outputs.image_embeds
         new_size = tuple(np.array(image_embeds.shape) - np.array((0, 1, 0)))
         class_token_out = torch.broadcast_to(image_embeds[:, :1, :], new_size)
@@ -1279,8 +1293,13 @@ class OwlViTForObjectDetection(OwlViTPreTrainedModel):
             image_embeds.shape[-1],
         )
         image_embeds = image_embeds.reshape(new_size)
+        text_embeds = outputs.text_embeds
+
+        # Last hidden states from text and vision transformers
+        text_model_last_hidden_state = outputs.text_model_output.last_hidden_state
+        vision_model_last_hidden_state = outputs.vision_model_output.last_hidden_state
 
-        return (image_embeds, text_embeds)
+        return (text_embeds, image_embeds, text_model_last_hidden_state, vision_model_last_hidden_state)
 
     @add_start_docstrings_to_model_forward(OWLVIT_OBJECT_DETECTION_INPUTS_DOCSTRING)
     @replace_return_docstrings(output_type=OwlViTObjectDetectionOutput, config_class=OwlViTConfig)
@@ -1334,12 +1353,8 @@ class OwlViTForObjectDetection(OwlViTPreTrainedModel):
         )
         return_dict = return_dict if return_dict is not None else self.config.return_dict
 
-        # Return last hidden states of text and vision transformers
-        text_model_last_hidden_states = None
-        vision_model_last_hidden_states = None
-
-        if output_hidden_states:
-            outputs = self.owlvit(
+        # Embed images and text queries
+        outputs = self.image_text_embedder(
             input_ids=input_ids,
             pixel_values=pixel_values,
             attention_mask=attention_mask,
@@ -1347,19 +1362,15 @@ class OwlViTForObjectDetection(OwlViTPreTrainedModel):
             output_hidden_states=output_hidden_states,
         )
 
-        text_model_last_hidden_states = outputs[-2][0]
-        vision_model_last_hidden_states = outputs[-1][0]
+        # Last hidden states of text and vision transformers
+        text_model_last_hidden_state = outputs[2]
+        vision_model_last_hidden_state = outputs[3]
 
-        # Embed images and text queries
-        feature_map, query_embeds = self.image_text_embedder(
-            input_ids=input_ids,
-            pixel_values=pixel_values,
-            attention_mask=attention_mask,
-            output_attentions=output_attentions,
-        )
+        query_embeds = outputs[0]
+        feature_map = outputs[1]
 
-        batch_size, height, width, hidden_dim = feature_map.shape
-        image_feats = torch.reshape(feature_map, (batch_size, height * width, hidden_dim))
+        batch_size, num_patches, num_patches, hidden_dim = feature_map.shape
+        image_feats = torch.reshape(feature_map, (batch_size, num_patches * num_patches, hidden_dim))
 
         # Reshape from [batch_size * max_text_queries, hidden_dim] -> [batch_size, max_text_queries, hidden_dim]
         max_text_queries = input_ids.shape[0] // batch_size
@@ -1382,8 +1393,8 @@ class OwlViTForObjectDetection(OwlViTPreTrainedModel):
                 query_embeds,
                 feature_map,
                 class_embeds,
-                text_model_last_hidden_states,
-                vision_model_last_hidden_states,
+                text_model_last_hidden_state,
+                vision_model_last_hidden_state,
             )
             output = tuple(x for x in output if x is not None)
             return output
@@ -1394,6 +1405,6 @@ class OwlViTForObjectDetection(OwlViTPreTrainedModel):
             pred_boxes=pred_boxes,
             logits=pred_logits,
             class_embeds=class_embeds,
-            text_model_last_hidden_states=text_model_last_hidden_states,
-            vision_model_last_hidden_states=vision_model_last_hidden_states,
+            text_model_last_hidden_state=text_model_last_hidden_state,
+            vision_model_last_hidden_state=vision_model_last_hidden_state,
        )
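With the renamed singular output fields in place, the last hidden states can be read straight off the detection output. A short continuation of the usage sketch above (it assumes the same `model` and `inputs` objects from that sketch):

```python
# Continuation of the earlier sketch (same `model` and `inputs`).
with torch.no_grad():
    detections = model(**inputs)

# Renamed fields after this commit: singular `..._state`, not `..._states`.
print(detections.text_model_last_hidden_state.shape)    # text tower last hidden state
print(detections.vision_model_last_hidden_state.shape)  # vision tower last hidden state, (.., num_patches + 1, hidden_size)
```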