"tests/models/auto/__init__.py" did not exist on "dd4df80f0b77c8f8e07e502298df0121cada9ce8"
Unverified Commit 3bc65505 authored by Yih-Dar's avatar Yih-Dar Committed by GitHub
Browse files

Fix doctest for `Blip2ForConditionalGeneration` (#26737)



* fix

* fix

* fix

* fix

* fix

* fix

* fix

* fix

* fix

---------
Co-authored-by: default avatarydshieh <ydshieh@users.noreply.github.com>
parent e1cec434
......@@ -1272,14 +1272,10 @@ class Blip2Model(Blip2PreTrainedModel):
>>> import torch
>>> from transformers import AutoTokenizer, Blip2Model
>>> device = "cuda" if torch.cuda.is_available() else "cpu"
>>> model = Blip2Model.from_pretrained("Salesforce/blip2-opt-2.7b", torch_dtype=torch.float16)
>>> model.to(device) # doctest: +IGNORE_RESULT
>>> model = Blip2Model.from_pretrained("Salesforce/blip2-opt-2.7b")
>>> tokenizer = AutoTokenizer.from_pretrained("Salesforce/blip2-opt-2.7b")
>>> inputs = tokenizer(["a photo of a cat", "a photo of a dog"], padding=True, return_tensors="pt").to(device)
>>> inputs = tokenizer(["a photo of a cat"], padding=True, return_tensors="pt")
>>> text_features = model.get_text_features(**inputs)
```"""
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
......@@ -1333,16 +1329,12 @@ class Blip2Model(Blip2PreTrainedModel):
>>> import requests
>>> from transformers import AutoProcessor, Blip2Model
>>> device = "cuda" if torch.cuda.is_available() else "cpu"
>>> model = Blip2Model.from_pretrained("Salesforce/blip2-opt-2.7b", torch_dtype=torch.float16)
>>> model.to(device) # doctest: +IGNORE_RESULT
>>> model = Blip2Model.from_pretrained("Salesforce/blip2-opt-2.7b")
>>> processor = AutoProcessor.from_pretrained("Salesforce/blip2-opt-2.7b")
>>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
>>> image = Image.open(requests.get(url, stream=True).raw)
>>> inputs = processor(images=image, return_tensors="pt").to(device, torch.float16)
>>> inputs = processor(images=image, return_tensors="pt")
>>> image_outputs = model.get_image_features(**inputs)
```"""
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
......@@ -1381,15 +1373,12 @@ class Blip2Model(Blip2PreTrainedModel):
>>> import requests
>>> from transformers import Blip2Processor, Blip2Model
>>> device = "cuda" if torch.cuda.is_available() else "cpu"
>>> processor = Blip2Processor.from_pretrained("Salesforce/blip2-opt-2.7b")
>>> model = Blip2Model.from_pretrained("Salesforce/blip2-opt-2.7b", torch_dtype=torch.float16)
>>> model.to(device) # doctest: +IGNORE_RESULT
>>> model = Blip2Model.from_pretrained("Salesforce/blip2-opt-2.7b")
>>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
>>> image = Image.open(requests.get(url, stream=True).raw)
>>> inputs = processor(images=image, return_tensors="pt").to(device, torch.float16)
>>> inputs = processor(images=image, return_tensors="pt")
>>> qformer_outputs = model.get_qformer_features(**inputs)
```"""
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
......@@ -1654,7 +1643,7 @@ class Blip2ForConditionalGeneration(Blip2PreTrainedModel):
Examples:
Image captioning (without providing a text prompt):
Prepare processor, model and image input
```python
>>> from PIL import Image
......@@ -1666,13 +1655,16 @@ class Blip2ForConditionalGeneration(Blip2PreTrainedModel):
>>> processor = Blip2Processor.from_pretrained("Salesforce/blip2-opt-2.7b")
>>> model = Blip2ForConditionalGeneration.from_pretrained(
... "Salesforce/blip2-opt-2.7b", torch_dtype=torch.float16
... )
>>> model.to(device) # doctest: +IGNORE_RESULT
... "Salesforce/blip2-opt-2.7b", load_in_8bit=True, device_map={"": 0}, torch_dtype=torch.float16
... ) # doctest: +IGNORE_RESULT
>>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
>>> image = Image.open(requests.get(url, stream=True).raw)
```
Image captioning (without providing a text prompt):
```python
>>> inputs = processor(images=image, return_tensors="pt").to(device, torch.float16)
>>> generated_ids = model.generate(**inputs)
......@@ -1684,21 +1676,6 @@ class Blip2ForConditionalGeneration(Blip2PreTrainedModel):
Visual question answering (prompt = question):
```python
>>> from PIL import Image
>>> import requests
>>> from transformers import Blip2Processor, Blip2ForConditionalGeneration
>>> import torch
>>> device = "cuda" if torch.cuda.is_available() else "cpu"
>>> processor = Blip2Processor.from_pretrained("Salesforce/blip2-opt-2.7b")
>>> model = Blip2ForConditionalGeneration.from_pretrained(
... "Salesforce/blip2-opt-2.7b", load_in_8bit=True, device_map={"": 0}, torch_dtype=torch.float16
... ) # doctest: +IGNORE_RESULT
>>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
>>> image = Image.open(requests.get(url, stream=True).raw)
>>> prompt = "Question: how many cats are there? Answer:"
>>> inputs = processor(images=image, text=prompt, return_tensors="pt").to(device="cuda", dtype=torch.float16)
......@@ -1712,20 +1689,10 @@ class Blip2ForConditionalGeneration(Blip2PreTrainedModel):
This greatly reduces the amount of memory used by the model while maintaining the same performance.
```python
>>> from PIL import Image
>>> import requests
>>> from transformers import Blip2Processor, Blip2ForConditionalGeneration
>>> import torch
>>> processor = Blip2Processor.from_pretrained("Salesforce/blip2-flan-t5-xl")
>>> model = Blip2ForConditionalGeneration.from_pretrained(
... "Salesforce/blip2-flan-t5-xl", load_in_8bit=True, device_map={"": 0}, torch_dtype=torch.bfloat16
... "Salesforce/blip2-opt-2.7b", load_in_8bit=True, device_map={"": 0}, torch_dtype=torch.bfloat16
... ) # doctest: +IGNORE_RESULT
>>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
>>> image = Image.open(requests.get(url, stream=True).raw)
>>> prompt = "Question: how many cats are there? Answer:"
>>> inputs = processor(images=image, text=prompt, return_tensors="pt").to(device="cuda", dtype=torch.bfloat16)
>>> generated_ids = model.generate(**inputs)
......
docs/source/en/generation_strategies.md
docs/source/en/model_doc/ctrl.md
docs/source/en/task_summary.md
src/transformers/models/blip_2/modeling_blip_2.py
src/transformers/models/ctrl/modeling_ctrl.py
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment