"git@developer.sourcefind.cn:chenpangpang/transformers.git" did not exist on "74a3cebfa51b539bfcfa79b33686cc090b7074e8"
Unverified commit 44bd590a, authored by Jungwoo Park and committed by GitHub
Browse files

Pix2Struct: fix wrong broadcast axis of attention mask in visual encoder (#23976)



* fix wrong broadcast axis of attention mask in visual encoder

* fix slow tests

---------
Co-authored-by: younesbelkada <younesbelkada@gmail.com>
parent 7824fa43
...@@ -210,7 +210,7 @@ class Pix2StructVisionAttention(nn.Module): ...@@ -210,7 +210,7 @@ class Pix2StructVisionAttention(nn.Module):
attention_mask = torch.ones((batch_size, seq_length), device=scores.device, dtype=scores.dtype) attention_mask = torch.ones((batch_size, seq_length), device=scores.device, dtype=scores.dtype)
if attention_mask.dim() == 2: if attention_mask.dim() == 2:
position_bias = position_bias + attention_mask[:, None, :, None].to(position_bias.device) position_bias = position_bias + attention_mask[:, None, None, :].to(position_bias.device)
else: else:
# (batch_size, n_heads, seq_length, key_length) # (batch_size, n_heads, seq_length, key_length)
position_bias = position_bias + attention_mask.to(position_bias.device) position_bias = position_bias + attention_mask.to(position_bias.device)
...@@ -1695,7 +1695,7 @@ class Pix2StructForConditionalGeneration(Pix2StructPreTrainedModel): ...@@ -1695,7 +1695,7 @@ class Pix2StructForConditionalGeneration(Pix2StructPreTrainedModel):
>>> generated_ids = model.generate(**inputs, max_new_tokens=50) >>> generated_ids = model.generate(**inputs, max_new_tokens=50)
>>> generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True)[0] >>> generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
>>> print(generated_text) >>> print(generated_text)
A picture of a stop sign with a red stop sign on it. A picture of a stop sign with a red stop sign
``` ```
Training: Training:
...@@ -1719,7 +1719,7 @@ class Pix2StructForConditionalGeneration(Pix2StructPreTrainedModel): ...@@ -1719,7 +1719,7 @@ class Pix2StructForConditionalGeneration(Pix2StructPreTrainedModel):
>>> outputs = model(**inputs, labels=labels) >>> outputs = model(**inputs, labels=labels)
>>> loss = outputs.loss >>> loss = outputs.loss
>>> print(f"{loss.item():.5f}") >>> print(f"{loss.item():.5f}")
5.95566 5.94282
```""" ```"""
use_cache = use_cache if use_cache is not None else self.config.text_config.use_cache use_cache = use_cache if use_cache is not None else self.config.text_config.use_cache
return_dict = return_dict if return_dict is not None else self.config.use_return_dict return_dict = return_dict if return_dict is not None else self.config.use_return_dict
......
...@@ -757,12 +757,12 @@ class Pix2StructIntegrationTest(unittest.TestCase): ...@@ -757,12 +757,12 @@ class Pix2StructIntegrationTest(unittest.TestCase):
self.assertEqual( self.assertEqual(
processor.decode(predictions[0], skip_special_tokens=True), processor.decode(predictions[0], skip_special_tokens=True),
"A picture of a stop sign with a red stop sign on it.", "A picture of a stop sign with a red stop sign",
) )
self.assertEqual( self.assertEqual(
processor.decode(predictions[1], skip_special_tokens=True), processor.decode(predictions[1], skip_special_tokens=True),
"An photography of the Temple Bar and the Temple Bar.", "An photography of the Temple Bar and other places in the city.",
) )
def test_vqa_model(self): def test_vqa_model(self):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment