Unverified Commit a64bcb56 authored by Yih-Dar, committed by GitHub

Fix OwlViT torchscript tests (#18347)


Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>
parent a4ee463d
@@ -1153,7 +1153,6 @@ class OwlViTClassPredictionHead(nn.Module):
 
 class OwlViTForObjectDetection(OwlViTPreTrainedModel):
     config_class = OwlViTConfig
-    main_input_name = "pixel_values"
 
     def __init__(self, config: OwlViTConfig):
         super().__init__(config)
@@ -1246,8 +1245,8 @@ class OwlViTForObjectDetection(OwlViTPreTrainedModel):
 
     def image_text_embedder(
         self,
+        pixel_values: torch.FloatTensor,
         input_ids: torch.Tensor,
-        pixel_values: torch.FloatTensor,
         attention_mask: torch.Tensor,
         output_attentions: Optional[bool] = None,
     ) -> torch.FloatTensor:
@@ -1284,8 +1283,8 @@ class OwlViTForObjectDetection(OwlViTPreTrainedModel):
     @replace_return_docstrings(output_type=OwlViTObjectDetectionOutput, config_class=OwlViTConfig)
     def forward(
         self,
+        pixel_values: torch.FloatTensor,
         input_ids: torch.Tensor,
-        pixel_values: torch.FloatTensor,
         attention_mask: Optional[torch.Tensor] = None,
         output_attentions: Optional[bool] = None,
         output_hidden_states: Optional[bool] = None,
@@ -1338,8 +1337,8 @@ class OwlViTForObjectDetection(OwlViTPreTrainedModel):
 
         if output_hidden_states:
             outputs = self.owlvit(
+                pixel_values=pixel_values,
                 input_ids=input_ids,
-                pixel_values=pixel_values,
                 attention_mask=attention_mask,
                 output_attentions=output_attentions,
                 output_hidden_states=output_hidden_states,
@@ -1350,8 +1349,8 @@ class OwlViTForObjectDetection(OwlViTPreTrainedModel):
 
         # Embed images and text queries
         feature_map, query_embeds = self.image_text_embedder(
+            pixel_values=pixel_values,
            input_ids=input_ids,
-            pixel_values=pixel_values,
             attention_mask=attention_mask,
             output_attentions=output_attentions,
         )
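Note on the reordering in the hunks above (a reading of the fix, not spelled out in the commit message): the torchscript tests trace the model with example inputs passed positionally, starting from the model's main input, which for OwlViT detection is pixel_values. torch.jit.trace binds example inputs to forward's parameters in declaration order, so pixel_values has to be the first parameter. A minimal sketch under that assumption; TinyDetector is a hypothetical stand-in, not the real OwlViTForObjectDetection:

```python
import torch


class TinyDetector(torch.nn.Module):
    """Hypothetical stand-in for an image+text model such as OwlViTForObjectDetection."""

    def forward(self, pixel_values: torch.Tensor, input_ids: torch.Tensor):
        # Toy computation standing in for the real image/text embedding.
        return pixel_values.mean(), input_ids.float().sum()


model = TinyDetector().eval()
pixel_values = torch.rand(1, 3, 768, 768)
input_ids = torch.randint(0, 100, (1, 16))

# torch.jit.trace feeds the example inputs positionally, so the first example
# input lands on the first forward() parameter. With pixel_values declared
# first, tracing with (pixel_values, input_ids) lines up with the signature.
traced = torch.jit.trace(model, (pixel_values, input_ids))
print(traced(pixel_values, input_ids))
```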
@@ -1374,7 +1373,7 @@ class OwlViTForObjectDetection(OwlViTPreTrainedModel):
         pred_boxes = self.box_predictor(image_feats, feature_map)
 
         if not return_dict:
-            return (
+            output = (
                 pred_logits,
                 pred_boxes,
                 query_embeds,
@@ -1383,6 +1382,8 @@ class OwlViTForObjectDetection(OwlViTPreTrainedModel):
                 text_model_last_hidden_states,
                 vision_model_last_hidden_states,
             )
+            output = tuple(x for x in output if x is not None)
+            return output
 
         return OwlViTObjectDetectionOutput(
             image_embeds=feature_map,
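The tuple filtering added above is, on the same reading, also about tracing: torch.jit.trace only supports tensors and (possibly nested) tuples, lists, or dicts of tensors as outputs, so a tuple that still contains None entries (e.g. hidden states that were not requested) fails to trace. Dropping the Nones before returning keeps the non-return_dict output traceable. A minimal sketch with a hypothetical head function:

```python
import torch


def head(x: torch.Tensor):
    maybe_hidden = None  # e.g. hidden states that were not requested
    output = (x * 2, maybe_hidden)
    # Returning `output` as-is would fail under torch.jit.trace, which rejects
    # None entries in traced outputs; filtering them out, as the diff does,
    # leaves a plain tuple of tensors.
    return tuple(t for t in output if t is not None)


traced = torch.jit.trace(head, (torch.rand(2, 4),))
print(traced(torch.rand(2, 4)))  # -> a one-element tuple holding a (2, 4) tensor
```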