Unverified Commit 7a6efe1e authored by Stas Bekman, committed by GitHub

[idefics] idefics-9b test use 4bit quant (#25734)

parent fecf0856
@@ -16,8 +16,15 @@
 import unittest
 
-from transformers import IdeficsConfig, is_torch_available, is_vision_available
-from transformers.testing_utils import TestCasePlus, require_torch, require_vision, slow, torch_device
+from transformers import BitsAndBytesConfig, IdeficsConfig, is_torch_available, is_vision_available
+from transformers.testing_utils import (
+    TestCasePlus,
+    require_bitsandbytes,
+    require_torch,
+    require_vision,
+    slow,
+    torch_device,
+)
 from transformers.utils import cached_property
 
 from ...test_configuration_common import ConfigTester
@@ -434,6 +441,7 @@ class IdeficsModelIntegrationTest(TestCasePlus):
     def default_processor(self):
         return IdeficsProcessor.from_pretrained("HuggingFaceM4/idefics-9b") if is_vision_available() else None
 
+    @require_bitsandbytes
     @slow
     def test_inference_natural_language_visual_reasoning(self):
         cat_image_path = self.tests_dir / "fixtures/tests_samples/COCO/000000039769.png"
@@ -459,7 +467,14 @@ class IdeficsModelIntegrationTest(TestCasePlus):
             ],
         ]
 
-        model = IdeficsForVisionText2Text.from_pretrained("HuggingFaceM4/idefics-9b").to(torch_device)
+        # the CI gpu is small so using quantization to fit
+        quantization_config = BitsAndBytesConfig(
+            load_in_4bit=True,
+            bnb_4bit_compute_dtype="float16",
+        )
+        model = IdeficsForVisionText2Text.from_pretrained(
+            "HuggingFaceM4/idefics-9b", quantization_config=quantization_config, device_map="auto"
+        )
         processor = self.default_processor
         inputs = processor(prompts, return_tensors="pt").to(torch_device)
         generated_ids = model.generate(**inputs, max_length=100)
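For reference, here is a minimal standalone sketch of the same 4-bit load outside the test harness. The checkpoint name and the BitsAndBytesConfig values are taken from the diff above; the rest of the script (the "cat.png" image path, the prompt text) is illustrative only and assumes a CUDA GPU with bitsandbytes and accelerate installed.

# Minimal sketch of the 4-bit load used by the updated test (not part of the commit).
# Assumes: torch with CUDA, bitsandbytes, accelerate; "cat.png" is a placeholder image path.
from PIL import Image

from transformers import BitsAndBytesConfig, IdeficsForVisionText2Text, IdeficsProcessor

checkpoint = "HuggingFaceM4/idefics-9b"

# Store weights in 4-bit and run compute in fp16 so the 9B model fits on a small GPU.
quantization_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_compute_dtype="float16",
)

processor = IdeficsProcessor.from_pretrained(checkpoint)
model = IdeficsForVisionText2Text.from_pretrained(
    checkpoint,
    quantization_config=quantization_config,
    device_map="auto",  # accelerate places the quantized weights on the available device(s)
)

# Interleaved text/image prompt, mirroring the structure used in the integration test.
prompts = [["User:", Image.open("cat.png"), "Describe this image.\nAssistant:"]]
inputs = processor(prompts, return_tensors="pt").to("cuda")
generated_ids = model.generate(**inputs, max_length=100)
print(processor.batch_decode(generated_ids, skip_special_tokens=True)[0])

Note that the test itself still only runs when slow tests are enabled (RUN_SLOW=1), and the new @require_bitsandbytes decorator additionally skips it when bitsandbytes is not installed.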