Unverified Commit 00247ea0 authored by Jinho Park's avatar Jinho Park Committed by GitHub
Browse files

add bbox input validation (#26294)

parent 24553206
......@@ -877,6 +877,9 @@ class BrosModel(BrosPreTrainedModel):
else:
raise ValueError("You have to specify either input_ids or inputs_embeds")
if bbox is None:
raise ValueError("You have to specify bbox")
batch_size, seq_length = input_shape
device = input_ids.device if input_ids is not None else inputs_embeds.device
......@@ -924,13 +927,11 @@ class BrosModel(BrosPreTrainedModel):
past_key_values_length=past_key_values_length,
)
bbox_position_embeddings = None
if bbox is not None:
# if bbox has 2 points (4 float tensors) per token, convert it to 4 points (8 float tensors) per token
if bbox.shape[-1] == 4:
bbox = bbox[:, :, [0, 1, 2, 1, 2, 3, 0, 3]]
scaled_bbox = bbox * self.config.bbox_scale
bbox_position_embeddings = self.bbox_embeddings(scaled_bbox)
# if bbox has 2 points (4 float tensors) per token, convert it to 4 points (8 float tensors) per token
if bbox.shape[-1] == 4:
bbox = bbox[:, :, [0, 1, 2, 1, 2, 3, 0, 3]]
scaled_bbox = bbox * self.config.bbox_scale
bbox_position_embeddings = self.bbox_embeddings(scaled_bbox)
encoder_outputs = self.encoder(
embedding_output,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment