"tests/models/bart/test_modeling_bart.py" did not exist on "ca6b80cadbde650879028dcd733b6a7e8dd56760"
Unverified commit a5ca56ff authored by Younes Belkada, committed by GitHub

Supporting seq2seq models for `bitsandbytes` integration (#18579)

* Supporting seq2seq models for `bitsandbytes` integration

- `bitsandbytes` integration now supports seq2seq models
- additionally check whether a model has tied weights

* small modification

- tie the weights before looking at tied weights!
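
The note to "tie the weights before looking at tied weights" matters because `transformers` models may only tie their input/output embeddings on demand. Below is a minimal sketch, not part of the commit, of the detection pattern this change adds, assuming `t5-small` as an example checkpoint and using `accelerate`'s `init_empty_weights` and `find_tied_parameters`:

```python
from copy import deepcopy

from accelerate import init_empty_weights
from accelerate.utils import find_tied_parameters
from transformers import AutoConfig, AutoModelForSeq2SeqLM

config = AutoConfig.from_pretrained("t5-small")
with init_empty_weights():
    # Parameters live on the meta device, so the deepcopy below is essentially free
    model = AutoModelForSeq2SeqLM.from_config(config)

tied_model = deepcopy(model)
tied_model.tie_weights()  # ties are created lazily, hence "tie before looking"
has_tied_params = len(find_tied_parameters(tied_model)) > 0
print(has_tied_params)  # True for t5-small: lm_head shares weights with the embeddings
```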
parent ed1924e8
src/transformers/utils/bitsandbytes.py
+from copy import deepcopy
+
 from transformers.utils import is_accelerate_available, is_bitsandbytes_available
@@ -9,6 +11,7 @@ if is_bitsandbytes_available():
 if is_accelerate_available():
     from accelerate import init_empty_weights
+    from accelerate.utils import find_tied_parameters


 def set_module_8bit_tensor_to_device(module, tensor_name, device, value=None):
@@ -132,8 +135,17 @@ def get_key_to_not_convert(model):
         model (`torch.nn.Module`):
             Input model
     """
+    # Create a copy of the model and tie the weights, then
+    # check if it contains tied weights
+    tied_model = deepcopy(model)  # this has 0 cost since it is done inside the `init_empty_weights` context manager
+    tied_model.tie_weights()
+    has_tied_params = len(find_tied_parameters(tied_model)) > 0
+
+    # Check if it is a base model
+    is_base_model = not hasattr(model, model.base_model_prefix)
+
     # Ignore this for base models (BertModel, GPT2Model, etc.)
-    if not hasattr(model, model.base_model_prefix):
+    if (not has_tied_params) and is_base_model:
         return ""

     # otherwise they have an attached head
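
For context, `get_key_to_not_convert` returns the module to keep in full precision (typically the output head). With the tied-weights check, a seq2seq model whose `lm_head` is tied to the shared embeddings no longer falls through the base-model shortcut. A hedged usage sketch of the end result (requires a CUDA GPU with `bitsandbytes` and `accelerate` installed; `t5-small` is just an example checkpoint):

```python
import bitsandbytes as bnb
from transformers import AutoModelForSeq2SeqLM

model = AutoModelForSeq2SeqLM.from_pretrained("t5-small", load_in_8bit=True, device_map="auto")

# Inner linear layers are swapped for 8-bit modules...
assert isinstance(model.encoder.block[0].layer[0].SelfAttention.q, bnb.nn.Linear8bitLt)
# ...while the tied output head is kept out of the int8 conversion
assert not isinstance(model.lm_head, bnb.nn.Linear8bitLt)
```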
tests/mixed_int8/test_mixed_int8.py
@@ -15,7 +15,14 @@
 import gc
 import unittest

-from transformers import AutoModel, AutoModelForCausalLM, AutoModelForSequenceClassification, AutoTokenizer, pipeline
+from transformers import (
+    AutoModel,
+    AutoModelForCausalLM,
+    AutoModelForSeq2SeqLM,
+    AutoModelForSequenceClassification,
+    AutoTokenizer,
+    pipeline,
+)
 from transformers.testing_utils import (
     is_torch_available,
     require_accelerate,
@@ -106,12 +113,21 @@ class MixedInt8ModelClassesTest(BaseMixedInt8Test):
         super().setUp()
         # model_name
         self.model_name = "bigscience/bloom-560m"
-        # Models and tokenizer
+        self.seq_to_seq_name = "t5-small"
+
+        # Different types of model
         self.base_model = AutoModel.from_pretrained(self.model_name, load_in_8bit=True, device_map="auto")
+        # Sequence classification model
         self.sequence_model = AutoModelForSequenceClassification.from_pretrained(
             self.model_name, load_in_8bit=True, device_map="auto"
         )
+        # CausalLM model
         self.model_8bit = AutoModelForCausalLM.from_pretrained(self.model_name, load_in_8bit=True, device_map="auto")
+        # Seq2seq model
+        self.seq_to_seq_model = AutoModelForSeq2SeqLM.from_pretrained(
+            self.seq_to_seq_name, load_in_8bit=True, device_map="auto"
+        )

     def tearDown(self):
         r"""
@@ -121,6 +137,7 @@ class MixedInt8ModelClassesTest(BaseMixedInt8Test):
         del self.base_model
         del self.sequence_model
         del self.model_8bit
+        del self.seq_to_seq_model

         gc.collect()
         torch.cuda.empty_cache()
@@ -138,6 +155,7 @@ class MixedInt8ModelClassesTest(BaseMixedInt8Test):
         # Other heads should be nn.Parameter
         self.assertTrue(self.model_8bit.lm_head.weight.__class__ == torch.nn.Parameter)
         self.assertTrue(self.sequence_model.score.weight.__class__ == torch.nn.Parameter)
+        self.assertTrue(self.seq_to_seq_model.lm_head.weight.__class__ == torch.nn.Parameter)


 class MixedInt8TestPipeline(BaseMixedInt8Test):
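
The new assertion mirrors the existing head checks: weights of converted modules become `bnb.nn.Int8Params`, while modules skipped by `get_key_to_not_convert` keep plain `torch.nn.Parameter` weights. A small sketch of the same check outside the test harness (same GPU and `bitsandbytes` assumptions as above):

```python
import bitsandbytes as bnb
import torch
from transformers import AutoModelForSeq2SeqLM

model = AutoModelForSeq2SeqLM.from_pretrained("t5-small", load_in_8bit=True, device_map="auto")

print(type(model.lm_head.weight))  # torch.nn.Parameter -- head left in higher precision
print(type(model.encoder.block[0].layer[0].SelfAttention.q.weight))  # bnb.nn.Int8Params
```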