Unverified commit a64bcb56, authored by Yih-Dar, committed by GitHub
Browse files

Fix OwlViT torchscript tests (#18347)


Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>
parent a4ee463d
......@@ -1153,7 +1153,6 @@ class OwlViTClassPredictionHead(nn.Module):
class OwlViTForObjectDetection(OwlViTPreTrainedModel):
config_class = OwlViTConfig
-main_input_name = "pixel_values"
def __init__(self, config: OwlViTConfig):
super().__init__(config)
......@@ -1246,8 +1245,8 @@ class OwlViTForObjectDetection(OwlViTPreTrainedModel):
def image_text_embedder(
self,
-pixel_values: torch.FloatTensor,
input_ids: torch.Tensor,
+pixel_values: torch.FloatTensor,
attention_mask: torch.Tensor,
output_attentions: Optional[bool] = None,
) -> torch.FloatTensor:
......@@ -1284,8 +1283,8 @@ class OwlViTForObjectDetection(OwlViTPreTrainedModel):
@replace_return_docstrings(output_type=OwlViTObjectDetectionOutput, config_class=OwlViTConfig)
def forward(
self,
-pixel_values: torch.FloatTensor,
input_ids: torch.Tensor,
+pixel_values: torch.FloatTensor,
attention_mask: Optional[torch.Tensor] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
......@@ -1338,8 +1337,8 @@ class OwlViTForObjectDetection(OwlViTPreTrainedModel):
if output_hidden_states:
outputs = self.owlvit(
-pixel_values=pixel_values,
input_ids=input_ids,
+pixel_values=pixel_values,
attention_mask=attention_mask,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
......@@ -1350,8 +1349,8 @@ class OwlViTForObjectDetection(OwlViTPreTrainedModel):
# Embed images and text queries
feature_map, query_embeds = self.image_text_embedder(
-pixel_values=pixel_values,
input_ids=input_ids,
+pixel_values=pixel_values,
attention_mask=attention_mask,
output_attentions=output_attentions,
)
......@@ -1374,7 +1373,7 @@ class OwlViTForObjectDetection(OwlViTPreTrainedModel):
pred_boxes = self.box_predictor(image_feats, feature_map)
if not return_dict:
-return (
+output = (
pred_logits,
pred_boxes,
query_embeds,
......@@ -1383,6 +1382,8 @@ class OwlViTForObjectDetection(OwlViTPreTrainedModel):
text_model_last_hidden_states,
vision_model_last_hidden_states,
)
+output = tuple(x for x in output if x is not None)
+return output
return OwlViTObjectDetectionOutput(
image_embeds=feature_map,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment