"git@developer.sourcefind.cn:chenpangpang/transformers.git" did not exist on "fa4eeb4fd342cdbad50d1eeacdd7d7d7bc23b080"
Unverified Commit e1cd7863 authored by Younes Belkada's avatar Younes Belkada Committed by GitHub
Browse files

[`BLIP`] fix doctest (#21217)



* fix `blip` doctest

* Update src/transformers/models/blip/modeling_blip.py
Co-authored-by: default avatarNielsRogge <48327001+NielsRogge@users.noreply.github.com>
Co-authored-by: default avatarNielsRogge <48327001+NielsRogge@users.noreply.github.com>
parent 4e730b38
...@@ -1176,17 +1176,30 @@ class BlipForQuestionAnswering(BlipPreTrainedModel): ...@@ -1176,17 +1176,30 @@ class BlipForQuestionAnswering(BlipPreTrainedModel):
>>> url = "http://images.cocodataset.org/val2017/000000039769.jpg" >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
>>> image = Image.open(requests.get(url, stream=True).raw) >>> image = Image.open(requests.get(url, stream=True).raw)
>>> text = "How many cats are in the picture?"
>>> # training
>>> text = "How many cats are in the picture?"
>>> label = "2"
>>> inputs = processor(images=image, text=text, return_tensors="pt") >>> inputs = processor(images=image, text=text, return_tensors="pt")
>>> labels = processor(text=label, return_tensors="pt").input_ids
>>> inputs["labels"] = labels
>>> outputs = model(**inputs) >>> outputs = model(**inputs)
>>> loss = outputs.loss
>>> loss.backward()
>>> # inference
>>> text = "How many cats are in the picture?"
>>> inputs = processor(images=image, text=text, return_tensors="pt")
>>> outputs = model.generate(**inputs)
>>> print(processor.decode(outputs[0], skip_special_tokens=True))
2
```""" ```"""
if labels is None and decoder_input_ids is None: if labels is None and decoder_input_ids is None:
raise ValueError( raise ValueError(
"Either `decoder_input_ids` or `labels` should be passed when calling `forward` with" "Either `decoder_input_ids` or `labels` should be passed when calling `forward` with"
" `BlipForQuestionAnswering`. if you are training the model make sure that `labels` is passed, if you" " `BlipForQuestionAnswering`. if you are training the model make sure that `labels` is passed, if you"
" are using the model for inference make sure that `decoder_input_ids` is passed." " are using the model for inference make sure that `decoder_input_ids` is passed or call `generate`"
) )
return_dict = return_dict if return_dict is not None else self.config.use_return_dict return_dict = return_dict if return_dict is not None else self.config.use_return_dict
...@@ -1392,8 +1405,8 @@ class BlipForImageTextRetrieval(BlipPreTrainedModel): ...@@ -1392,8 +1405,8 @@ class BlipForImageTextRetrieval(BlipPreTrainedModel):
>>> import requests >>> import requests
>>> from transformers import BlipProcessor, BlipForImageTextRetrieval >>> from transformers import BlipProcessor, BlipForImageTextRetrieval
>>> model = BlipForImageTextRetrieval.from_pretrained("Salesforce/blip-itm-base") >>> model = BlipForImageTextRetrieval.from_pretrained("Salesforce/blip-itm-base-coco")
>>> processor = BlipProcessor.from_pretrained("Salesforce/blip-itm-base") >>> processor = BlipProcessor.from_pretrained("Salesforce/blip-itm-base-coco")
>>> url = "http://images.cocodataset.org/val2017/000000039769.jpg" >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
>>> image = Image.open(requests.get(url, stream=True).raw) >>> image = Image.open(requests.get(url, stream=True).raw)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment