Unverified commit fbdb978e, authored by lewtun and committed by GitHub

Fix Llava chat template examples (#30130)

parent b752ad30
@@ -43,13 +43,13 @@ The original code can be found [here](https://github.com/haotian-liu/LLaVA/tree/
 - For better results, we recommend users to prompt the model with the correct prompt format:
 ```bash
-"USER: <image>\n<prompt>ASSISTANT:"
+"USER: <image>\n<prompt> ASSISTANT:"
 ```
 For multiple turns conversation:
 ```bash
-"USER: <image>\n<prompt1>ASSISTANT: <answer1>USER: <prompt2>ASSISTANT: <answer2>USER: <prompt3>ASSISTANT:"
+"USER: <image>\n<prompt1> ASSISTANT: <answer1></s>USER: <prompt2> ASSISTANT: <answer2></s>USER: <prompt3> ASSISTANT:"
 ```
 ### Using Flash Attention 2
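The corrected format adds a space before `ASSISTANT:` and, in the multi-turn case, separates turns with the `</s>` end-of-sequence token. A minimal sketch of the corrected single-turn format in use, assembled from the model id, image URL, and processor calls that appear in the docstring and tests below (the sketch itself is not part of this commit):

```python
import requests
from PIL import Image
from transformers import AutoProcessor, LlavaForConditionalGeneration

model_id = "llava-hf/llava-1.5-7b-hf"
model = LlavaForConditionalGeneration.from_pretrained(model_id)
processor = AutoProcessor.from_pretrained(model_id)

# Corrected single-turn format: a space before "ASSISTANT:", and the
# question placed directly after the "<image>\n" placeholder.
prompt = "USER: <image>\nWhat's the content of the image? ASSISTANT:"

url = "https://www.ilankelman.org/stopsigns/australia.jpg"
image = Image.open(requests.get(url, stream=True).raw)

inputs = processor(text=prompt, images=image, return_tensors="pt")
generate_ids = model.generate(**inputs, max_new_tokens=15)
print(processor.batch_decode(generate_ids, skip_special_tokens=True)[0])
```

For a follow-up turn, append the model's answer plus `</s>` and the next `USER:` segment to the prompt, as in the multi-turn template above.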
@@ -12,7 +12,8 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-""" PyTorch Llava model."""
+"""PyTorch Llava model."""
 from dataclasses import dataclass
 from typing import List, Optional, Tuple, Union
@@ -388,16 +389,16 @@ class LlavaForConditionalGeneration(LlavaPreTrainedModel):
 >>> model = LlavaForConditionalGeneration.from_pretrained("llava-hf/llava-1.5-7b-hf")
 >>> processor = AutoProcessor.from_pretrained("llava-hf/llava-1.5-7b-hf")
->>> prompt = "<image>\nUSER: What's the content of the image?\nASSISTANT:"
+>>> prompt = "USER: <image>\nWhat's the content of the image? ASSISTANT:"
 >>> url = "https://www.ilankelman.org/stopsigns/australia.jpg"
 >>> image = Image.open(requests.get(url, stream=True).raw)
 >>> inputs = processor(text=prompt, images=image, return_tensors="pt")
 >>> # Generate
->>> generate_ids = model.generate(**inputs, max_length=30)
+>>> generate_ids = model.generate(**inputs, max_new_tokens=15)
 >>> processor.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
-"\nUSER: What's the content of the image?\nASSISTANT: The image features a stop sign on a street corner"
+"USER: \nWhat's the content of the image? ASSISTANT: The image features a busy city street with a stop sign prominently displayed"
 ```"""
 output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
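Besides the prompt fix, the docstring now uses `max_new_tokens=15` instead of `max_length=30`. In `generate`, `max_length` bounds the prompt plus the continuation, so a long prompt can silently consume the whole budget; `max_new_tokens` bounds only the newly generated tokens. A sketch of the contrast, reusing `model` and `inputs` from the docstring example above:

```python
# Bounds the *total* sequence length: if the tokenized prompt already
# fills most of the 30 positions, few or no new tokens are generated.
out_total = model.generate(**inputs, max_length=30)

# Bounds only the *generated* continuation: up to 15 new tokens are
# produced regardless of how long the prompt is.
out_new = model.generate(**inputs, max_new_tokens=15)
```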
@@ -12,7 +12,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-""" Testing suite for the PyTorch Llava model. """
+"""Testing suite for the PyTorch Llava model."""
 import copy
 import gc
@@ -398,13 +398,13 @@ class LlavaForConditionalGenerationIntegrationTest(unittest.TestCase):
 model = LlavaForConditionalGeneration.from_pretrained("llava-hf/llava-1.5-7b-hf", load_in_4bit=True)
 processor = AutoProcessor.from_pretrained(model_id)
-prompt = "USER: <image>\nWhat are the things I should be cautious about when I visit this place?\nASSISTANT:"
+prompt = "USER: <image>\nWhat are the things I should be cautious about when I visit this place? ASSISTANT:"
 image_file = "https://llava-vl.github.io/static/images/view.jpg"
 raw_image = Image.open(requests.get(image_file, stream=True).raw)
 inputs = processor(prompt, raw_image, return_tensors="pt").to(torch_device, torch.float16)
 output = model.generate(**inputs, max_new_tokens=900, do_sample=False)
-EXPECTED_DECODED_TEXT = "USER: \nWhat are the things I should be cautious about when I visit this place?\nASSISTANT: When visiting this place, which is a pier or dock extending over a body of water, there are a few things to be cautious about. First, be aware of the weather conditions, as sudden changes in weather can make the pier unsafe to walk on. Second, be mindful of the water depth and any potential hazards, such as submerged rocks or debris, that could cause accidents or injuries. Additionally, be cautious of the presence of wildlife, such as birds or fish, and avoid disturbing their natural habitats. Lastly, be aware of any local regulations or guidelines for the use of the pier, as some areas may be restricted or prohibited for certain activities." # fmt: skip
+EXPECTED_DECODED_TEXT = "USER: \nWhat are the things I should be cautious about when I visit this place? ASSISTANT: When visiting this place, which is a pier or dock extending over a body of water, there are a few things to be cautious about. First, be aware of the weather conditions, as sudden changes in weather can make the pier unsafe to walk on. Second, be mindful of the water depth and any potential hazards, such as submerged rocks or debris, that could cause accidents or injuries. Additionally, be cautious of the tides and currents, as they can change rapidly and pose a risk to swimmers or those who venture too close to the edge of the pier. Finally, be respectful of the environment and other visitors, and follow any posted rules or guidelines for the area." # fmt: skip
 self.assertEqual(
 processor.decode(output[0], skip_special_tokens=True),
@@ -421,8 +421,8 @@ class LlavaForConditionalGenerationIntegrationTest(unittest.TestCase):
 processor = AutoProcessor.from_pretrained(model_id)
 prompts = [
-"USER: <image>\nWhat are the things I should be cautious about when I visit this place? What should I bring with me?\nASSISTANT:",
-"USER: <image>\nWhat is this?\nASSISTANT:",
+"USER: <image>\nWhat are the things I should be cautious about when I visit this place? What should I bring with me? ASSISTANT:",
+"USER: <image>\nWhat is this? ASSISTANT:",
 ]
 image1 = Image.open(requests.get("https://llava-vl.github.io/static/images/view.jpg", stream=True).raw)
 image2 = Image.open(requests.get("http://images.cocodataset.org/val2017/000000039769.jpg", stream=True).raw)
@@ -431,7 +431,7 @@ class LlavaForConditionalGenerationIntegrationTest(unittest.TestCase):
 output = model.generate(**inputs, max_new_tokens=20)
-EXPECTED_DECODED_TEXT = ['USER: \nWhat are the things I should be cautious about when I visit this place? What should I bring with me?\nASSISTANT: When visiting this place, which appears to be a dock or pier extending over a body of water', 'USER: \nWhat is this?\nASSISTANT: The image features two cats lying down on a pink couch. One cat is located on'] # fmt: skip
+EXPECTED_DECODED_TEXT = ['USER: \nWhat are the things I should be cautious about when I visit this place? What should I bring with me? ASSISTANT: When visiting this place, which is a pier or dock extending over a body of water, you', 'USER: \nWhat is this? ASSISTANT: The image features two cats lying down on a pink couch. One cat is located on'] # fmt: skip
 self.assertEqual(processor.batch_decode(output, skip_special_tokens=True), EXPECTED_DECODED_TEXT)
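The batched test elides the processor call between its two hunks. A hedged reconstruction of that batched path, mirroring the surrounding test code; `padding=True` and the `.to(...)` cast are assumptions, since the actual line is not shown in this diff:

```python
import requests
import torch
from PIL import Image
from transformers import AutoProcessor, LlavaForConditionalGeneration

model_id = "llava-hf/llava-1.5-7b-hf"
model = LlavaForConditionalGeneration.from_pretrained(model_id, load_in_4bit=True)
processor = AutoProcessor.from_pretrained(model_id)

# Prompts in the corrected format, as in the updated test.
prompts = [
    "USER: <image>\nWhat are the things I should be cautious about when I visit this place? What should I bring with me? ASSISTANT:",
    "USER: <image>\nWhat is this? ASSISTANT:",
]
image1 = Image.open(requests.get("https://llava-vl.github.io/static/images/view.jpg", stream=True).raw)
image2 = Image.open(requests.get("http://images.cocodataset.org/val2017/000000039769.jpg", stream=True).raw)

# padding=True is an assumption: the two prompts tokenize to different
# lengths and must be padded before they can share one batch tensor.
inputs = processor(prompts, images=[image1, image2], return_tensors="pt", padding=True).to(
    model.device, torch.float16
)
output = model.generate(**inputs, max_new_tokens=20)
print(processor.batch_decode(output, skip_special_tokens=True))
```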