Unverified Commit 50f82e12 authored by Matt's avatar Matt Committed by GitHub
Browse files

Fix docstrings for TF BLIP (#22618)

* Fix docstrings for TFBLIP

* Fix missing line in TF port!

* Use values from torch tests now other bugs fixed

* Use values from torch tests now other bugs fixed

* Fix doctest string
parent ce06e478
......@@ -1020,7 +1020,7 @@ class TFBlipModel(TFBlipPreTrainedModel):
)
pooled_output = text_outputs[1]
text_features = self.text_projection(pooled_output)
text_features = self.blip.text_projection(pooled_output)
return text_features
......@@ -1057,7 +1057,7 @@ class TFBlipModel(TFBlipPreTrainedModel):
vision_outputs = self.blip.vision_model(pixel_values=pixel_values, return_dict=return_dict)
pooled_output = vision_outputs[1] # pooled_output
image_features = self.visual_projection(pooled_output)
image_features = self.blip.visual_projection(pooled_output)
return image_features
......@@ -1238,7 +1238,7 @@ class TFBlipForConditionalGeneration(TFBlipPreTrainedModel):
>>> outputs = model.generate(**inputs)
>>> print(processor.decode(outputs[0], skip_special_tokens=True))
two cats are laying on a couch
two cats sleeping on a couch
```
"""
......@@ -1410,7 +1410,6 @@ class TFBlipForQuestionAnswering(TFBlipPreTrainedModel):
>>> inputs["labels"] = labels
>>> outputs = model(**inputs)
>>> loss = outputs.loss
>>> loss.backward()
>>> # inference
>>> text = "How many cats are in the picture?"
......
......@@ -462,6 +462,7 @@ class TFBlipTextEncoder(tf.keras.layers.Layer):
next_decoder_cache += (layer_outputs[-1],)
if output_attentions:
all_self_attentions = all_self_attentions + (layer_outputs[1],)
all_cross_attentions = all_cross_attentions + (layer_outputs[2],)
if output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states,)
......
......@@ -783,7 +783,7 @@ class TFBlipModelIntegrationTest(unittest.TestCase):
# Test output
self.assertEqual(
predictions[0].numpy().tolist(),
[30522, 1037, 3861, 1997, 1037, 2450, 3564, 2006, 1996, 3509, 2007, 2014, 3899, 102],
[30522, 1037, 3861, 1997, 1037, 2450, 1998, 2014, 3899, 2006, 1996, 3509, 102],
)
def test_inference_vqa(self):
......@@ -810,6 +810,6 @@ class TFBlipModelIntegrationTest(unittest.TestCase):
out_itm = model(**inputs)
out = model(**inputs, use_itm_head=False, training=False)
expected_scores = tf.convert_to_tensor([[0.9798, 0.0202]])
expected_scores = tf.convert_to_tensor([[0.0029, 0.9971]])
self.assertTrue(np.allclose(tf.nn.softmax(out_itm[0]).numpy(), expected_scores, rtol=1e-3, atol=1e-3))
self.assertTrue(np.allclose(out[0], tf.convert_to_tensor([[0.5053]]), rtol=1e-3, atol=1e-3))
self.assertTrue(np.allclose(out[0], tf.convert_to_tensor([[0.5162]]), rtol=1e-3, atol=1e-3))
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment