Unverified Commit c9a06714 authored by Younes Belkada's avatar Younes Belkada Committed by GitHub
Browse files

[`bnb`] fix `bnb` decoders bug (#21688)

* fix `bnb` decoders bug

* make fixup
parent f56174ac
...@@ -171,4 +171,13 @@ def get_keys_to_not_convert(model): ...@@ -171,4 +171,13 @@ def get_keys_to_not_convert(model):
intersection = set(list_last_module) - set(tied_keys) intersection = set(list_last_module) - set(tied_keys)
list_untouched = tied_keys + list(intersection) list_untouched = tied_keys + list(intersection)
return [module_name.split(".")[0] for module_name in list_untouched] # remove ".weight" from the keys
names_to_remove = [".weight", ".bias"]
filtered_module_names = []
for name in list_untouched:
for name_to_remove in names_to_remove:
if name_to_remove in name:
name = name.replace(name_to_remove, "")
filtered_module_names.append(name)
return filtered_module_names
...@@ -269,10 +269,16 @@ class MixedInt8T5Test(unittest.TestCase): ...@@ -269,10 +269,16 @@ class MixedInt8T5Test(unittest.TestCase):
`flan-t5-small` uses `T5DenseGatedActDense` whereas `t5-small` uses `T5DenseReluDense`. We need to test `flan-t5-small` uses `T5DenseGatedActDense` whereas `t5-small` uses `T5DenseReluDense`. We need to test
both cases. both cases.
""" """
import bitsandbytes as bnb
from transformers import T5ForConditionalGeneration from transformers import T5ForConditionalGeneration
# test with `t5-small` # test with `t5-small`
model = T5ForConditionalGeneration.from_pretrained(self.model_name, load_in_8bit=True, device_map="auto") model = T5ForConditionalGeneration.from_pretrained(self.model_name, load_in_8bit=True, device_map="auto")
# there was a bug with decoders - this test checks that it is fixed
self.assertTrue(isinstance(model.decoder.block[0].layer[0].SelfAttention.q, bnb.nn.Linear8bitLt))
encoded_input = self.tokenizer(self.input_text, return_tensors="pt").to(0) encoded_input = self.tokenizer(self.input_text, return_tensors="pt").to(0)
_ = model.generate(**encoded_input) _ = model.generate(**encoded_input)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment