Unverified Commit 5090ea3f authored by Fraser Mince, committed by GitHub

Fix llava half precision and autocast issues (#29721)

* Ensure input_embeds and image_features are the same dtype in autocast

* Fix nans in half precision llava-next and fix autocasting behavior.

* Fix styling issues.

* fix randn newline instantiation

* fix broken slow llava test

* Fix llava next init.

* fix styling issues

* [run-slow]llava,llava_next

* fix styling issues
parent d57ffb48
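
For context, the failure mode this commit targets: with float32 weights, `torch.autocast` runs the multi-modal projector in float16, so the projected image features and the float32 text embeddings end up with different dtypes when they are merged, which previously surfaced as dtype errors or NaN logits. A minimal repro sketch, assuming a CUDA device and the llava-hf/llava-1.5-7b-hf checkpoint used elsewhere in the tests (prompt and image are illustrative, not taken from the diff):

import torch
import requests
from PIL import Image
from transformers import AutoProcessor, LlavaForConditionalGeneration

# Sketch only: load the checkpoint in float32, run one forward pass under autocast,
# and check the logits for NaNs, mirroring what the new fp16 tests assert.
model_id = "llava-hf/llava-1.5-7b-hf"
processor = AutoProcessor.from_pretrained(model_id)
model = LlavaForConditionalGeneration.from_pretrained(model_id).to("cuda").eval()

url = "http://images.cocodataset.org/val2017/000000039769.jpg"
image = Image.open(requests.get(url, stream=True).raw)
prompt = "USER: <image>\nWhat is this? ASSISTANT:"
inputs = processor(prompt, image, return_tensors="pt").to("cuda")

with torch.autocast(device_type="cuda", dtype=torch.float16):
    logits = model(**inputs).logits
assert not torch.isnan(logits).any().item()
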
src/transformers/models/llava/modeling_llava.py

@@ -438,6 +438,7 @@ class LlavaForConditionalGeneration(LlavaPreTrainedModel):
                 )

             image_features = self.multi_modal_projector(selected_image_feature)
+            inputs_embeds = inputs_embeds.to(image_features.dtype)
             inputs_embeds, attention_mask, labels, position_ids = self._merge_input_ids_with_image_features(
                 image_features, inputs_embeds, input_ids, attention_mask, labels
             )
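
The one-line addition above casts the text embeddings to whatever dtype the projector produced before the two streams are merged. A standalone sketch of the underlying behavior, with made-up shapes rather than the model's real dimensions: under autocast, `nn.Linear` emits float16 even though its weights are float32, so the image features and embeddings disagree unless one side is cast.

import torch
import torch.nn as nn

# Toy illustration of the dtype mismatch the cast avoids; shapes are illustrative.
projector = nn.Linear(8, 8).cuda()                    # float32 weights
inputs_embeds = torch.randn(1, 4, 8, device="cuda")   # float32 text embeddings
image_hidden = torch.randn(1, 3, 8, device="cuda")

with torch.autocast(device_type="cuda", dtype=torch.float16):
    image_features = projector(image_hidden)          # autocast output is float16
    assert image_features.dtype != inputs_embeds.dtype
    # The fix: align the embeddings to the features before merging them.
    inputs_embeds = inputs_embeds.to(image_features.dtype)
    merged = torch.cat([image_features, inputs_embeds], dim=1)
    assert merged.dtype == torch.float16
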
src/transformers/models/llava_next/modeling_llava_next.py

@@ -14,6 +14,7 @@
 # limitations under the License.
 """ PyTorch Llava-NeXT model."""

+import math
 from dataclasses import dataclass
 from typing import List, Optional, Tuple, Union

@@ -306,8 +307,8 @@ class LlavaNextForConditionalGeneration(LlavaNextPreTrainedModel):
         self.vision_tower = AutoModel.from_config(config.vision_config)

         self.multi_modal_projector = LlavaNextMultiModalProjector(config)
-        self.image_newline = nn.Parameter(torch.empty(config.text_config.hidden_size, dtype=self.dtype))
+        embed_std = 1 / math.sqrt(config.text_config.hidden_size)
+        self.image_newline = nn.Parameter(torch.randn(config.text_config.hidden_size, dtype=self.dtype) * embed_std)

         self.vocab_size = config.text_config.vocab_size
         self.language_model = AutoModelForCausalLM.from_config(

@@ -543,7 +544,9 @@ class LlavaNextForConditionalGeneration(LlavaNextPreTrainedModel):
                         image_feature = torch.cat(
                             (
                                 image_feature,
-                                self.image_newline[:, None, None].expand(*image_feature.shape[:-1], 1),
+                                self.image_newline[:, None, None]
+                                .expand(*image_feature.shape[:-1], 1)
+                                .to(image_feature.dtype),
                             ),
                             dim=-1,
                         )

@@ -554,6 +557,7 @@ class LlavaNextForConditionalGeneration(LlavaNextPreTrainedModel):
                     image_feature = torch.cat((image_feature, self.image_newline[None]), dim=0)
                 new_image_features.append(image_feature)
             image_features = torch.stack(new_image_features, dim=0)
+            inputs_embeds = inputs_embeds.to(image_features.dtype)
             inputs_embeds, attention_mask, labels, position_ids = self._merge_input_ids_with_image_features(
                 image_features, inputs_embeds, input_ids, attention_mask, labels
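
Two related changes in LLaVA-NeXT: `image_newline` is now drawn from a normal distribution scaled by 1/sqrt(hidden_size) instead of `torch.empty` (uninitialized memory can hold arbitrary values, including inf or NaN, which then contaminate every row it is concatenated onto), and the parameter is cast to the image-feature dtype so float16 features produced under autocast or `.half()` do not clash with a float32 parameter. A small sketch of both points, with an illustrative hidden size rather than the real config value:

import math
import torch
import torch.nn as nn

hidden_size = 64  # illustrative only

# Before: torch.empty leaves the parameter uninitialized (arbitrary values).
# After: draw from N(0, embed_std^2), the usual embedding-init scale.
embed_std = 1 / math.sqrt(hidden_size)
image_newline = nn.Parameter(torch.randn(hidden_size) * embed_std)

# When concatenating, cast the parameter to the feature dtype (e.g. float16).
image_feature = torch.randn(hidden_size, 2, 2, dtype=torch.float16)
newline_col = (
    image_newline[:, None, None]
    .expand(*image_feature.shape[:-1], 1)
    .to(image_feature.dtype)
)
image_feature = torch.cat((image_feature, newline_col), dim=-1)  # (hidden, 2, 3)
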
tests/models/llava/test_modeling_llava.py

@@ -157,6 +157,19 @@ class LlavaVisionText2TextModelTester:
         }
         return config, inputs_dict

+    def create_and_check_llava_model_fp16_forward(self, config, input_ids, pixel_values, attention_mask):
+        model = LlavaForConditionalGeneration(config=config)
+        model.to(torch_device)
+        model.eval()
+        with torch.autocast(device_type="cuda", dtype=torch.float16):
+            logits = model(
+                input_ids=input_ids,
+                attention_mask=attention_mask,
+                pixel_values=pixel_values.to(torch.bfloat16),
+                return_dict=True,
+            )["logits"]
+        self.parent.assertFalse(torch.isnan(logits).any().item())
+

 @require_torch
 class LlavaForConditionalGenerationModelTest(ModelTesterMixin, unittest.TestCase):

@@ -225,7 +238,7 @@ class LlavaForConditionalGenerationIntegrationTest(unittest.TestCase):
     @slow
     @require_bitsandbytes
-    def test_small_model_integration_test_llama(self):
+    def test_small_model_integration_test_llama_single(self):
         # Let' s make sure we test the preprocessing to replace what is used
         model_id = "llava-hf/llava-1.5-7b-hf"

@@ -238,7 +251,7 @@ class LlavaForConditionalGenerationIntegrationTest(unittest.TestCase):
         inputs = processor(prompt, raw_image, return_tensors="pt").to(torch_device, torch.float16)

         output = model.generate(**inputs, max_new_tokens=900, do_sample=False)
-        EXPECTED_DECODED_TEXT = "USER: \nWhat are the things I should be cautious about when I visit this place? ASSISTANT: When visiting this place, which is a pier or dock extending over a body of water, there are a few things to be cautious about. First, be aware of the weather conditions, as sudden changes in weather can make the pier unsafe to walk on. Second, be mindful of the water depth and any potential hazards, such as submerged rocks or debris, that could cause accidents or injuries. Additionally, be cautious of the tides and currents, as they can change rapidly and pose a risk to swimmers or those who venture too close to the edge of the pier. Finally, be respectful of the environment and other visitors, and follow any posted rules or guidelines for the area." # fmt: skip
+        EXPECTED_DECODED_TEXT = "USER: \nWhat are the things I should be cautious about when I visit this place? ASSISTANT: When visiting this place, which is a pier or dock extending over a body of water, there are a few things to be cautious about. First, be aware of the weather conditions, as sudden changes in weather can make the pier unsafe to walk on. Second, be mindful of the water depth and any potential hazards, such as submerged rocks or debris, that could cause accidents or injuries. Additionally, be cautious of the tides and currents, as they can change rapidly and pose a risk to swimmers or those who venture too close to the edge of the pier. Lastly, be respectful of the environment and other visitors, as the pier is a shared space where people can enjoy the view, relax, or engage in recreational activities." # fmt: skip

         self.assertEqual(
             processor.decode(output[0], skip_special_tokens=True),

@@ -267,7 +280,10 @@ class LlavaForConditionalGenerationIntegrationTest(unittest.TestCase):
         EXPECTED_DECODED_TEXT = ['USER: \nWhat are the things I should be cautious about when I visit this place? What should I bring with me? ASSISTANT: When visiting this place, which is a pier or dock extending over a body of water, you', 'USER: \nWhat is this? ASSISTANT: The image features two cats lying down on a pink couch. One cat is located on'] # fmt: skip

-        self.assertEqual(processor.batch_decode(output, skip_special_tokens=True), EXPECTED_DECODED_TEXT)
+        self.assertEqual(
+            processor.batch_decode(output, skip_special_tokens=True),
+            EXPECTED_DECODED_TEXT,
+        )

     @slow
     @require_bitsandbytes

@@ -287,7 +303,10 @@ class LlavaForConditionalGenerationIntegrationTest(unittest.TestCase):
         output = model.generate(**inputs, max_new_tokens=20)

         EXPECTED_DECODED_TEXT = ['USER: \nWhat are the things I should be cautious about when I visit this place? What should I bring with me?\nASSISTANT: When visiting this place, there are a few things to be cautious about and items to bring along', 'USER: \nWhat is this?\nASSISTANT: Cats'] # fmt: skip
-        self.assertEqual(self.processor.batch_decode(output, skip_special_tokens=True), EXPECTED_DECODED_TEXT)
+        self.assertEqual(
+            self.processor.batch_decode(output, skip_special_tokens=True),
+            EXPECTED_DECODED_TEXT,
+        )

     @slow
     @require_bitsandbytes

@@ -314,7 +333,10 @@ class LlavaForConditionalGenerationIntegrationTest(unittest.TestCase):
         EXPECTED_DECODED_TEXT = ['USER: \nWhat are the things I should be cautious about when I visit this place? What should I bring with me?\nASSISTANT: When visiting this place, which appears to be a dock or pier extending over a body of water', 'USER: \nWhat is this?\nASSISTANT: Two cats lying on a bed!\nUSER: \nAnd this?\nASSISTANT: A cat sleeping on a bed.'] # fmt: skip
-        self.assertEqual(processor.batch_decode(output, skip_special_tokens=True), EXPECTED_DECODED_TEXT)
+        self.assertEqual(
+            processor.batch_decode(output, skip_special_tokens=True),
+            EXPECTED_DECODED_TEXT,
+        )

     @slow
     @require_torch

@@ -342,7 +364,7 @@ class LlavaForConditionalGenerationIntegrationTest(unittest.TestCase):
         model = model.eval()

         EXPECTED_OUTPUT = [
-            "\n \nUSER: What's the the difference of two images?\nASSISTANT: In the two images, the primary difference is the presence of a small dog holding a flower in one",
+            "\n \nUSER: What's the the difference of two images?\nASSISTANT: In the two images, the primary difference is the presence of a small dog in one and a ll",
             "\nUSER: Describe the image.\nASSISTANT: The image features a small, fluffy dog sitting on a sidewalk. The dog is holding",
             "\nUSER: Describe the image.\nASSISTANT: The image features a lone, adult llama standing on a grassy hill. The llama",
         ]
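
The new `create_and_check_llava_model_fp16_forward` helper only exercises the autocast path on a CUDA device, since `torch.autocast(device_type="cuda", ...)` is hard-coded. A hypothetical test method (the method name, the `require_torch_gpu` gate, and the inputs_dict keys are illustrative assumptions, not part of this commit) that wires the helper into the model test class could look like:

    # Hypothetical wiring inside LlavaForConditionalGenerationModelTest, GPU-only.
    @require_torch_gpu
    def test_llava_model_fp16_forward(self):
        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
        self.model_tester.create_and_check_llava_model_fp16_forward(
            config,
            inputs_dict["input_ids"],
            inputs_dict["pixel_values"],
            inputs_dict["attention_mask"],
        )

Here `require_torch_gpu` is the existing decorator from transformers.testing_utils for GPU-only tests.
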
tests/models/llava_next/test_modeling_llava_next.py

@@ -27,11 +27,21 @@ from transformers import (
     is_torch_available,
     is_vision_available,
 )
-from transformers.testing_utils import require_bitsandbytes, require_torch, slow, torch_device
+from transformers.testing_utils import (
+    require_bitsandbytes,
+    require_torch,
+    slow,
+    torch_device,
+)

 from ...generation.test_utils import GenerationTesterMixin
 from ...test_configuration_common import ConfigTester
-from ...test_modeling_common import ModelTesterMixin, _config_zero_init, floats_tensor, ids_tensor
+from ...test_modeling_common import (
+    ModelTesterMixin,
+    _config_zero_init,
+    floats_tensor,
+    ids_tensor,
+)

 if is_torch_available():

@@ -157,6 +167,39 @@ class LlavaNextVisionText2TextModelTester:
         }
         return config, inputs_dict

+    def create_and_check_llava_next_model_fp16_forward(
+        self, config, input_ids, pixel_values, attention_mask, image_sizes
+    ):
+        model = LlavaNextForConditionalGeneration(config=config)
+        model.to(torch_device)
+        model.half()
+        model.eval()
+        logits = model(
+            input_ids=input_ids,
+            attention_mask=attention_mask,
+            image_sizes=image_sizes,
+            pixel_values=pixel_values.to(torch.bfloat16),
+            return_dict=True,
+        )["logits"]
+        self.parent.assertFalse(torch.isnan(logits).any().item())
+
+    def create_and_check_llava_next_model_fp16_autocast_forward(
+        self, config, input_ids, pixel_values, attention_mask, image_sizes
+    ):
+        config.torch_dtype = torch.float16
+        model = LlavaNextForConditionalGeneration(config=config)
+        model.to(torch_device)
+        model.eval()
+        with torch.autocast(device_type="cuda", dtype=torch.float16):
+            logits = model(
+                input_ids=input_ids,
+                attention_mask=attention_mask,
+                image_sizes=image_sizes,
+                pixel_values=pixel_values.to(torch.bfloat16),
+                return_dict=True,
+            )["logits"]
+        self.parent.assertFalse(torch.isnan(logits).any().item())
+

 @require_torch
 class LlavaNextForConditionalGenerationModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase):

@@ -239,14 +282,20 @@ class LlavaNextForConditionalGenerationIntegrationTest(unittest.TestCase):
         inputs = self.processor(self.prompt, self.image, return_tensors="pt")

         # verify inputs against original implementation
-        filepath = hf_hub_download(repo_id="nielsr/test-image", filename="llava_1_6_input_ids.pt", repo_type="dataset")
+        filepath = hf_hub_download(
+            repo_id="nielsr/test-image",
+            filename="llava_1_6_input_ids.pt",
+            repo_type="dataset",
+        )
         original_input_ids = torch.load(filepath, map_location="cpu")
         # replace -200 by image_token_index (since we use token ID = 32000 for the image token)
         original_input_ids[original_input_ids == -200] = model.config.image_token_index
         assert original_input_ids[0].tolist() == inputs.input_ids[0].tolist()

         filepath = hf_hub_download(
-            repo_id="nielsr/test-image", filename="llava_1_6_pixel_values.pt", repo_type="dataset"
+            repo_id="nielsr/test-image",
+            filename="llava_1_6_pixel_values.pt",
+            repo_type="dataset",
         )
         original_pixel_values = torch.load(filepath, map_location="cpu")
         assert torch.allclose(original_pixel_values, inputs.pixel_values.half())

@@ -257,7 +306,11 @@ class LlavaNextForConditionalGenerationIntegrationTest(unittest.TestCase):
             output = model(**inputs)

         expected_slice = torch.tensor(
-            [[-4.7695, -4.5664, -0.2786], [-10.6250, -10.8906, -2.5254], [-6.7383, -7.2461, -0.6787]],
+            [
+                [-4.7695, -4.5664, -0.2786],
+                [-10.6250, -10.8906, -2.5254],
+                [-6.7383, -7.2461, -0.6787],
+            ],
             dtype=torch.float32,
             device=torch_device,
         )

@@ -282,7 +335,10 @@ class LlavaNextForConditionalGenerationIntegrationTest(unittest.TestCase):
         cats_image = Image.open(requests.get(url, stream=True).raw)

         inputs = self.processor(
-            [self.prompt, self.prompt], images=[self.image, cats_image], return_tensors="pt", padding=True
+            [self.prompt, self.prompt],
+            images=[self.image, cats_image],
+            return_tensors="pt",
+            padding=True,
         ).to(torch_device)

         # make sure image_sizes are the same

@@ -292,7 +348,10 @@ class LlavaNextForConditionalGenerationIntegrationTest(unittest.TestCase):
         output = model.generate(**inputs, max_new_tokens=20)

         EXPECTED_DECODED_TEXT = ['[INST] \nWhat is shown in this image? [/INST] The image appears to be a radar chart, which is a type of multi-dimensional plot that displays', '[INST] \nWhat is shown in this image? [/INST] The image shows two cats lying on a pink surface, which appears to be a couch or a cush'] # fmt: skip
-        self.assertEqual(self.processor.batch_decode(output, skip_special_tokens=True), EXPECTED_DECODED_TEXT)
+        self.assertEqual(
+            self.processor.batch_decode(output, skip_special_tokens=True),
+            EXPECTED_DECODED_TEXT,
+        )

     @slow
     @require_bitsandbytes
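
The reflowed integration-test assertions above go together with the precision fixes; exact generations still depend on hardware, dtype, and library versions. A hedged end-to-end sketch of exercising LLaVA-NeXT in half precision (checkpoint id is a llava-hf hosted LLaVA-NeXT release; prompt and image are illustrative):

import torch
import requests
from PIL import Image
from transformers import AutoProcessor, LlavaNextForConditionalGeneration

# Sketch only: load the weights directly in float16 and generate a short answer.
model_id = "llava-hf/llava-v1.6-mistral-7b-hf"
processor = AutoProcessor.from_pretrained(model_id)
model = LlavaNextForConditionalGeneration.from_pretrained(
    model_id, torch_dtype=torch.float16, low_cpu_mem_usage=True
).to("cuda")

url = "http://images.cocodataset.org/val2017/000000039769.jpg"
image = Image.open(requests.get(url, stream=True).raw)
prompt = "[INST] <image>\nWhat is shown in this image? [/INST]"

inputs = processor(prompt, image, return_tensors="pt").to("cuda", torch.float16)
output = model.generate(**inputs, max_new_tokens=20)
print(processor.decode(output[0], skip_special_tokens=True))
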