Unverified commit 50db7ca4, authored by Younes Belkada, committed by GitHub

FIX [`quantization` / `ESM`] Fix ESM 8bit / 4bit with bitsandbytes (#29329)



* fix ESM 8bit

* Apply suggestions from code review
Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>

* fixup

---------
Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
parent 2858d6c6
@@ -377,7 +377,7 @@ class EsmSelfAttention(nn.Module):
         if head_mask is not None:
             attention_probs = attention_probs * head_mask

-        context_layer = torch.matmul(attention_probs, value_layer)
+        context_layer = torch.matmul(attention_probs.to(value_layer.dtype), value_layer)

         context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
         new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
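For context on what this one-line change guards against: with bitsandbytes 8bit/4bit, the quantized linear layers return half-precision activations, so `value_layer` can be fp16 while the softmax-ed `attention_probs` stay in fp32, and `torch.matmul` rejects mixed dtypes. A minimal standalone sketch of the failure mode, with assumed shapes and plain PyTorch (no quantized weights involved):

```python
import torch

# attention_probs comes out of softmax in float32
attention_probs = torch.softmax(torch.randn(1, 2, 4, 4), dim=-1)
# under 8bit/4bit quantization, value_layer can be half precision
value_layer = torch.randn(1, 2, 4, 8, dtype=torch.float16)

try:
    torch.matmul(attention_probs, value_layer)  # RuntimeError: mixed dtypes
except RuntimeError as e:
    print(e)

# the fix: cast the probabilities to the value tensor's dtype first
context_layer = torch.matmul(attention_probs.to(value_layer.dtype), value_layer)
print(context_layer.dtype)  # torch.float16
```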
@@ -121,7 +121,7 @@ class Bnb4BitHfQuantizer(HfQuantizer):
         import bitsandbytes as bnb

         module, tensor_name = get_module_from_name(model, param_name)
-        if isinstance(module._parameters[tensor_name], bnb.nn.Params4bit):
+        if isinstance(module._parameters.get(tensor_name, None), bnb.nn.Params4bit):
             # Add here check for loaded components' dtypes once serialization is implemented
             return True
         elif isinstance(module, bnb.nn.Linear4bit) and tensor_name == "bias":
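The switch from `module._parameters[tensor_name]` to `.get(tensor_name, None)` matters because not every checkpoint entry routed through this check is a registered parameter; ESM's rotary embedding, for instance, registers `inv_freq` as a buffer, which lives in `module._buffers`. A toy reproduction of the difference, with a hypothetical `Rotary` module standing in for the real one:

```python
import torch
from torch import nn

class Rotary(nn.Module):
    """Stand-in for a rotary embedding module that keeps inv_freq as a buffer."""
    def __init__(self, dim: int = 8):
        super().__init__()
        # buffers go into module._buffers, not module._parameters
        self.register_buffer("inv_freq", torch.ones(dim))

module, tensor_name = Rotary(), "inv_freq"

# old behavior: direct indexing raises during weight loading
try:
    _ = module._parameters[tensor_name]
except KeyError:
    print("KeyError: 'inv_freq'")

# new behavior: returns None, so isinstance(..., bnb.nn.Params4bit) is simply False
print(module._parameters.get(tensor_name, None))  # -> None
```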
@@ -139,7 +139,7 @@ class Bnb8BitHfQuantizer(HfQuantizer):
         import bitsandbytes as bnb

         module, tensor_name = get_module_from_name(model, param_name)
-        if isinstance(module._parameters[tensor_name], bnb.nn.Int8Params):
+        if isinstance(module._parameters.get(tensor_name, None), bnb.nn.Int8Params):
             if self.pre_quantized:
                 if param_name.replace("weight", "SCB") not in state_dict.keys():
                     raise ValueError("Missing quantization component `SCB`")
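Same fix on the 8bit path. The `SCB` check just below it relies on a naming convention worth spelling out: for a pre-quantized LLM.int8() checkpoint, every int8 weight is expected to ship with its quantization statistics under a key that swaps `weight` for `SCB`. A quick illustration with a made-up parameter name:

```python
# hypothetical ESM parameter name, used only to show the key convention
param_name = "esm.encoder.layer.0.attention.self.query.weight"

# the companion tensor holding the LLM.int8() quantization statistics
print(param_name.replace("weight", "SCB"))
# -> esm.encoder.layer.0.attention.self.query.SCB
```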
@@ -18,7 +18,7 @@
 import unittest

 from transformers import EsmConfig, is_torch_available
-from transformers.testing_utils import TestCasePlus, require_torch, slow, torch_device
+from transformers.testing_utils import TestCasePlus, require_bitsandbytes, require_torch, slow, torch_device

 from ...test_configuration_common import ConfigTester
 from ...test_modeling_common import ModelTesterMixin, ids_tensor, random_attention_mask
@@ -303,9 +303,9 @@ class EsmModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
         pass


+@slow
 @require_torch
 class EsmModelIntegrationTest(TestCasePlus):
-    @slow
     def test_inference_masked_lm(self):
         with torch.no_grad():
             model = EsmForMaskedLM.from_pretrained("facebook/esm2_t6_8M_UR50D")
@@ -323,7 +323,6 @@ class EsmModelIntegrationTest(TestCasePlus):
         )

         self.assertTrue(torch.allclose(output[:, :3, :3], expected_slice, atol=1e-4))

-    @slow
     def test_inference_no_head(self):
         with torch.no_grad():
             model = EsmModel.from_pretrained("facebook/esm2_t6_8M_UR50D")
@@ -336,3 +335,18 @@ class EsmModelIntegrationTest(TestCasePlus):
             [[[0.1444, 0.5413, 0.3248], [0.3034, 0.0053, 0.3108], [0.3228, -0.2499, 0.3415]]]
         )
         self.assertTrue(torch.allclose(output[:, :3, :3], expected_slice, atol=1e-4))
+
+    @require_bitsandbytes
+    def test_inference_bitsandbytes(self):
+        model = EsmForMaskedLM.from_pretrained("facebook/esm2_t36_3B_UR50D", load_in_8bit=True)
+
+        input_ids = torch.tensor([[0, 6, 4, 13, 5, 4, 16, 12, 11, 7, 2]])
+        # Just test if inference works
+        with torch.no_grad():
+            _ = model(input_ids)[0]
+
+        model = EsmForMaskedLM.from_pretrained("facebook/esm2_t36_3B_UR50D", load_in_4bit=True)
+
+        input_ids = torch.tensor([[0, 6, 4, 13, 5, 4, 16, 12, 11, 7, 2]])
+        # Just test if inference works
+        _ = model(input_ids)[0]
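Outside the test suite, what the new `test_inference_bitsandbytes` exercises looks roughly like the sketch below as user code. It assumes a CUDA machine with `bitsandbytes` and `accelerate` installed, uses the small `facebook/esm2_t6_8M_UR50D` checkpoint instead of the 3B one to keep it light, and goes through the `BitsAndBytesConfig` API rather than the bare `load_in_8bit=True` flag:

```python
import torch
from transformers import AutoTokenizer, BitsAndBytesConfig, EsmForMaskedLM

# load the ESM checkpoint with bitsandbytes 8bit quantization
model = EsmForMaskedLM.from_pretrained(
    "facebook/esm2_t6_8M_UR50D",
    quantization_config=BitsAndBytesConfig(load_in_8bit=True),
)
tokenizer = AutoTokenizer.from_pretrained("facebook/esm2_t6_8M_UR50D")

# a short protein sequence, tokenized and moved to the model's device
inputs = tokenizer("MKTAYIAKQR", return_tensors="pt").to(model.device)
with torch.no_grad():
    logits = model(**inputs).logits
print(logits.shape)
```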