Unverified commit 44bd590a, authored by Jungwoo Park, committed by GitHub
Browse files

Pix2Struct: fix wrong broadcast axis of attention mask in visual encoder (#23976)



* fix wrong broadcast axis of attention mask in visual encoder

* fix slow tests

---------
Co-authored-by: younesbelkada <younesbelkada@gmail.com>
parent 7824fa43
......@@ -210,7 +210,7 @@ class Pix2StructVisionAttention(nn.Module):
attention_mask = torch.ones((batch_size, seq_length), device=scores.device, dtype=scores.dtype)
if attention_mask.dim() == 2:
- position_bias = position_bias + attention_mask[:, None, :, None].to(position_bias.device)
+ position_bias = position_bias + attention_mask[:, None, None, :].to(position_bias.device)
else:
# (batch_size, n_heads, seq_length, key_length)
position_bias = position_bias + attention_mask.to(position_bias.device)
......@@ -1695,7 +1695,7 @@ class Pix2StructForConditionalGeneration(Pix2StructPreTrainedModel):
>>> generated_ids = model.generate(**inputs, max_new_tokens=50)
>>> generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
>>> print(generated_text)
- A picture of a stop sign with a red stop sign on it.
+ A picture of a stop sign with a red stop sign
```
Training:
......@@ -1719,7 +1719,7 @@ class Pix2StructForConditionalGeneration(Pix2StructPreTrainedModel):
>>> outputs = model(**inputs, labels=labels)
>>> loss = outputs.loss
>>> print(f"{loss.item():.5f}")
- 5.95566
+ 5.94282
```"""
use_cache = use_cache if use_cache is not None else self.config.text_config.use_cache
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
......
......@@ -757,12 +757,12 @@ class Pix2StructIntegrationTest(unittest.TestCase):
self.assertEqual(
processor.decode(predictions[0], skip_special_tokens=True),
- "A picture of a stop sign with a red stop sign on it.",
+ "A picture of a stop sign with a red stop sign",
)
self.assertEqual(
processor.decode(predictions[1], skip_special_tokens=True),
- "An photography of the Temple Bar and the Temple Bar.",
+ "An photography of the Temple Bar and other places in the city.",
)
def test_vqa_model(self):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment