"docs/source/de/testing.md" did not exist on "1470f731b67117341993f024150fd8318b96d8b9"
Unverified Commit 5d81a568, authored by Alara Dirik, committed by GitHub

Owlvit memory leak fix (#18734)

* fix memory leak
* fix typos
* use singular last hidden state variable names
* eliminate double call to self.owlvit to return last hidden states
* eliminate 2nd call to self.vision_model in OwlViTModel
parent 80367cd1
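
The leak itself: before this change, `OwlViTForObjectDetection.forward` ran the `self.owlvit` backbone an extra time (when `output_hidden_states` was set) purely to fetch the last hidden states, on top of the separate `get_text_features`/`get_image_features` calls inside `image_text_embedder`. A condensed before/after sketch, assembled from the hunks below (schematic only, not runnable on its own):

```python
# Before: up to three forward passes over the same inputs.
text_embeds = self.owlvit.get_text_features(input_ids=input_ids, attention_mask=attention_mask)
image_embeds = self.owlvit.get_image_features(pixel_values, return_projected=False)
if output_hidden_states:
    outputs = self.owlvit(input_ids=input_ids, pixel_values=pixel_values)  # extra full pass

# After: one pass through self.owlvit returns the projected features,
# the unprojected image embeddings, and both last hidden states.
outputs = self.owlvit(
    input_ids=input_ids,
    pixel_values=pixel_values,
    attention_mask=attention_mask,
    return_base_image_embeds=True,
)
```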
@@ -140,9 +140,9 @@ class OwlViTObjectDetectionOutput(ModelOutput):
         class_embeds (`torch.FloatTensor` of shape `(batch_size, num_patches, hidden_size)`):
             Class embeddings of all image patches. OWL-ViT represents images as a set of image patches where the total
             number of patches is (image_size / patch_size)**2.
-        text_model_last_hidden_states (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`)):
+        text_model_last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`)):
             Last hidden states extracted from the [`OwlViTTextModel`].
-        vision_model_last_hidden_states (`torch.FloatTensor` of shape `(batch_size, num_patches + 1, hidden_size)`)):
+        vision_model_last_hidden_state (`torch.FloatTensor` of shape `(batch_size, num_patches + 1, hidden_size)`)):
             Last hidden states extracted from the [`OwlViTVisionModel`]. OWL-ViT represents images as a set of image
             patches where the total number of patches is (image_size / patch_size)**2.
     """
@@ -154,8 +154,8 @@ class OwlViTObjectDetectionOutput(ModelOutput):
     text_embeds: torch.FloatTensor = None
     image_embeds: torch.FloatTensor = None
     class_embeds: torch.FloatTensor = None
-    text_model_last_hidden_states: Optional[torch.FloatTensor] = None
-    vision_model_last_hidden_states: Optional[torch.FloatTensor] = None
+    text_model_last_hidden_state: Optional[torch.FloatTensor] = None
+    vision_model_last_hidden_state: Optional[torch.FloatTensor] = None


 class OwlViTVisionEmbeddings(nn.Module):
@@ -516,6 +516,9 @@ OWLVIT_INPUTS_DOCSTRING = r"""
         output_hidden_states (`bool`, *optional*):
             Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
             more detail.
+        return_base_image_embeds (`bool`, *optional*):
+            Whether or not to return unprojected image embeddings. Set to `True` when `OwlViTModel` is called within
+            `OwlViTForObjectDetection`.
         return_dict (`bool`, *optional*):
             Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
 """
@@ -1013,6 +1016,7 @@ class OwlViTModel(OwlViTPreTrainedModel):
         return_loss: Optional[bool] = None,
         output_attentions: Optional[bool] = None,
         output_hidden_states: Optional[bool] = None,
+        return_base_image_embeds: Optional[bool] = None,
         return_dict: Optional[bool] = None,
     ) -> Union[Tuple, OwlViTOutput]:
         r"""
@@ -1040,6 +1044,9 @@ class OwlViTModel(OwlViTPreTrainedModel):
         )
         return_dict = return_dict if return_dict is not None else self.config.use_return_dict

+        # Whether to return unprojected image features
+        return_base_image_embeds = return_base_image_embeds if return_base_image_embeds is not None else False
+
         vision_outputs = self.vision_model(
             pixel_values=pixel_values,
             output_attentions=output_attentions,
@@ -1075,6 +1082,10 @@ class OwlViTModel(OwlViTPreTrainedModel):
         if return_loss:
             loss = owlvit_loss(logits_per_text)

+        if return_base_image_embeds:
+            last_hidden_state = vision_outputs[0]
+            image_embeds = self.vision_model.post_layernorm(last_hidden_state)
+
         if not return_dict:
             output = (logits_per_image, logits_per_text, text_embeds, image_embeds, text_outputs, vision_outputs)
             return ((loss,) + output) if loss is not None else output
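
So with the flag set, `image_embeds` in the returned output carries the post-layernorm hidden states of every token rather than the projected pooled features. The equivalent manual computation, reusing `model` and `inputs` from the usage sketch above (assumption: an `OwlViTModel` instance):

```python
# Reproduces the branch above by hand.
with torch.no_grad():
    vision_outputs = model.vision_model(pixel_values=inputs["pixel_values"])
    base_image_embeds = model.vision_model.post_layernorm(vision_outputs[0])
# base_image_embeds matches outputs.image_embeds when return_base_image_embeds=True.
```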
@@ -1170,15 +1181,15 @@ class OwlViTForObjectDetection(OwlViTPreTrainedModel):
     def normalize_grid_corner_coordinates(self, feature_map: torch.FloatTensor):
         # Computes normalized xy corner coordinates from feature_map.
         if not feature_map.ndim == 4:
-            raise ValueError("Expected input shape is [batch_size, num_channels, height, width]")
+            raise ValueError("Expected input shape is [batch_size, num_patches, num_patches, hidden_dim]")

         device = feature_map.device
-        height, width = feature_map.shape[1:3]
+        num_patches = feature_map.shape[1]

-        box_coordinates = np.stack(np.meshgrid(np.arange(1, width + 1), np.arange(1, height + 1)), axis=-1).astype(
-            np.float32
-        )
-        box_coordinates /= np.array([width, height], np.float32)
+        box_coordinates = np.stack(
+            np.meshgrid(np.arange(1, num_patches + 1), np.arange(1, num_patches + 1)), axis=-1
+        ).astype(np.float32)
+        box_coordinates /= np.array([num_patches, num_patches], np.float32)

         # Flatten (h, w, 2) -> (h*w, 2)
         box_coordinates = box_coordinates.reshape(
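
To see what the rewritten grid code produces, here is the same computation run standalone on a toy 3x3 patch grid. The hunk cuts off mid-call, so the final reshape is completed here per the `(h, w, 2) -> (h*w, 2)` comment in the source:

```python
import numpy as np

num_patches = 3  # toy grid; real models use e.g. 24 patches per side

# Same meshgrid/normalize steps as the rewritten code above.
box_coordinates = np.stack(
    np.meshgrid(np.arange(1, num_patches + 1), np.arange(1, num_patches + 1)), axis=-1
).astype(np.float32)
box_coordinates /= np.array([num_patches, num_patches], np.float32)

# Flatten (h, w, 2) -> (h*w, 2), per the comment in the source.
box_coordinates = box_coordinates.reshape(num_patches * num_patches, 2)

print(box_coordinates[:3])
# approx. [[0.333 0.333], [0.667 0.333], [1.0 0.333]]
# i.e. each patch's normalized lower-right corner, row by row
```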
@@ -1232,7 +1243,7 @@ class OwlViTForObjectDetection(OwlViTPreTrainedModel):
         image_feats: torch.FloatTensor,
         query_embeds: torch.FloatTensor,
         query_mask: torch.Tensor,
-    ) -> Tuple[torch.FloatTensor, torch.FloatTensor]:
+    ) -> Tuple[torch.FloatTensor]:
         """
         Args:
             image_feats:
@@ -1252,18 +1263,21 @@ class OwlViTForObjectDetection(OwlViTPreTrainedModel):
         pixel_values: torch.FloatTensor,
         attention_mask: torch.Tensor,
         output_attentions: Optional[bool] = None,
-    ) -> torch.FloatTensor:
-
-        # Encode text
-        text_embeds = self.owlvit.get_text_features(
-            input_ids=input_ids, attention_mask=attention_mask, output_attentions=output_attentions
-        )
+        output_hidden_states: Optional[bool] = None,
+    ) -> Tuple[torch.FloatTensor]:

-        # Encode image
-        image_embeds = self.owlvit.get_image_features(
-            pixel_values, return_projected=False, output_attentions=output_attentions
+        # Encode text and image
+        outputs = self.owlvit(
+            pixel_values=pixel_values,
+            input_ids=input_ids,
+            attention_mask=attention_mask,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            return_base_image_embeds=True,
         )

         # Resize class token
+        image_embeds = outputs.image_embeds
         new_size = tuple(np.array(image_embeds.shape) - np.array((0, 1, 0)))
         class_token_out = torch.broadcast_to(image_embeds[:, :1, :], new_size)
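
The `new_size` trick above is compact; here is what it does on toy shapes (standalone, illustrative values):

```python
import numpy as np
import torch

image_embeds = torch.randn(2, 5, 4)  # (batch, num_patches + 1, hidden)

# Drop one slot from the sequence dimension: (2, 5, 4) -> (2, 4, 4)
new_size = tuple(np.array(image_embeds.shape) - np.array((0, 1, 0)))

# Broadcast the CLS token across every patch position.
class_token_out = torch.broadcast_to(image_embeds[:, :1, :], new_size)
print(class_token_out.shape)  # torch.Size([2, 4, 4])
```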
@@ -1279,8 +1293,13 @@ class OwlViTForObjectDetection(OwlViTPreTrainedModel):
             image_embeds.shape[-1],
         )
         image_embeds = image_embeds.reshape(new_size)
+        text_embeds = outputs.text_embeds
+
+        # Last hidden states from text and vision transformers
+        text_model_last_hidden_state = outputs.text_model_output.last_hidden_state
+        vision_model_last_hidden_state = outputs.vision_model_output.last_hidden_state

-        return (image_embeds, text_embeds)
+        return (text_embeds, image_embeds, text_model_last_hidden_state, vision_model_last_hidden_state)

     @add_start_docstrings_to_model_forward(OWLVIT_OBJECT_DETECTION_INPUTS_DOCSTRING)
     @replace_return_docstrings(output_type=OwlViTObjectDetectionOutput, config_class=OwlViTConfig)
@@ -1334,12 +1353,8 @@ class OwlViTForObjectDetection(OwlViTPreTrainedModel):
         )
         return_dict = return_dict if return_dict is not None else self.config.return_dict

-        # Return last hidden states of text and vision transformers
-        text_model_last_hidden_states = None
-        vision_model_last_hidden_states = None
-
-        if output_hidden_states:
-            outputs = self.owlvit(
+        # Embed images and text queries
+        outputs = self.image_text_embedder(
             input_ids=input_ids,
             pixel_values=pixel_values,
             attention_mask=attention_mask,
@@ -1347,19 +1362,15 @@ class OwlViTForObjectDetection(OwlViTPreTrainedModel):
             output_hidden_states=output_hidden_states,
         )

-            text_model_last_hidden_states = outputs[-2][0]
-            vision_model_last_hidden_states = outputs[-1][0]
+        # Last hidden states of text and vision transformers
+        text_model_last_hidden_state = outputs[2]
+        vision_model_last_hidden_state = outputs[3]

-        # Embed images and text queries
-        feature_map, query_embeds = self.image_text_embedder(
-            input_ids=input_ids,
-            pixel_values=pixel_values,
-            attention_mask=attention_mask,
-            output_attentions=output_attentions,
-        )
+        query_embeds = outputs[0]
+        feature_map = outputs[1]

-        batch_size, height, width, hidden_dim = feature_map.shape
-        image_feats = torch.reshape(feature_map, (batch_size, height * width, hidden_dim))
+        batch_size, num_patches, num_patches, hidden_dim = feature_map.shape
+        image_feats = torch.reshape(feature_map, (batch_size, num_patches * num_patches, hidden_dim))

         # Reshape from [batch_size * max_text_queries, hidden_dim] -> [batch_size, max_text_queries, hidden_dim]
         max_text_queries = input_ids.shape[0] // batch_size
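
The renamed unpacking makes the square geometry explicit: the feature map is `(batch_size, num_patches, num_patches, hidden_dim)` and gets flattened into a token sequence. Standalone, on toy values:

```python
import torch

batch_size, num_patches, hidden_dim = 1, 24, 768  # 24 = image_size / patch_size
feature_map = torch.randn(batch_size, num_patches, num_patches, hidden_dim)

# Flatten the 2D patch grid into a sequence of patch features.
image_feats = torch.reshape(feature_map, (batch_size, num_patches * num_patches, hidden_dim))
print(image_feats.shape)  # torch.Size([1, 576, 768])
```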
@@ -1382,8 +1393,8 @@ class OwlViTForObjectDetection(OwlViTPreTrainedModel):
                 query_embeds,
                 feature_map,
                 class_embeds,
-                text_model_last_hidden_states,
-                vision_model_last_hidden_states,
+                text_model_last_hidden_state,
+                vision_model_last_hidden_state,
             )
             output = tuple(x for x in output if x is not None)
             return output
@@ -1394,6 +1405,6 @@ class OwlViTForObjectDetection(OwlViTPreTrainedModel):
             pred_boxes=pred_boxes,
             logits=pred_logits,
             class_embeds=class_embeds,
-            text_model_last_hidden_states=text_model_last_hidden_states,
-            vision_model_last_hidden_states=vision_model_last_hidden_states,
+            text_model_last_hidden_state=text_model_last_hidden_state,
+            vision_model_last_hidden_state=vision_model_last_hidden_state,
         )
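
End to end, the user-visible effect at this commit is the renamed singular output fields and a single backbone pass regardless of flags. A sketch of the detection path (checkpoint name and image URL are illustrative):

```python
import requests
import torch
from PIL import Image
from transformers import OwlViTForObjectDetection, OwlViTProcessor

processor = OwlViTProcessor.from_pretrained("google/owlvit-base-patch32")
model = OwlViTForObjectDetection.from_pretrained("google/owlvit-base-patch32")

url = "http://images.cocodataset.org/val2017/000000039769.jpg"
image = Image.open(requests.get(url, stream=True).raw)
inputs = processor(text=[["a photo of a cat", "a photo of a dog"]], images=image, return_tensors="pt")

with torch.no_grad():
    outputs = model(**inputs)

# The renamed, singular output fields, now filled from the single
# image_text_embedder call rather than a second backbone pass.
print(outputs.text_model_last_hidden_state.shape)    # (batch * num_queries, seq_len, hidden)
print(outputs.vision_model_last_hidden_state.shape)  # (batch, num_patches + 1, hidden)
```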