Unverified commit 370f0ca1 authored by Younes Belkada, committed by GitHub

[`bnb`] Let's make serialization of int8 models possible (#22177)



* make serialization of int8 models possible

* make fixup

* add docs

* add ability to push to hub and save pretrained

* fixes

* more addition

* more tests

* fix issues

* change variable

* clearer message

* adapt from suggestions

* few fixes

* remove unused function

* Update src/transformers/utils/quantization_config.py
Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

* address last comments

* last warning

* clarify doc

* protect import

* Update src/transformers/modeling_utils.py

* Apply suggestions from code review
Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

---------
Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
parent 523ca4e0
@@ -52,6 +52,37 @@ Note that once a model has been loaded in 8-bit it is currently not possible to
</Tip>
### Push quantized models on the 🤗 Hub
You can push a quantized model to the Hub by simply using the `push_to_hub` method. This will first push the quantization configuration file, then push the quantized model weights.
Make sure to use `bitsandbytes>0.37.2` (at the time of writing, we tested this with `bitsandbytes==0.38.0.post1`) to be able to use this feature.
```python
from transformers import AutoModelForCausalLM, AutoTokenizer
model = AutoModelForCausalLM.from_pretrained("bigscience/bloom-560m", device_map="auto", load_in_8bit=True)
tokenizer = AutoTokenizer.from_pretrained("bigscience/bloom-560m")
model.push_to_hub("bloom-560m-8bit")
```
<Tip warning={true}>
Pushing 8-bit models to the Hub is strongly encouraged for large models. This allows the community to benefit from the reduced memory footprint, for example to load large models on a Google Colab instance.
</Tip>
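You can also serialize the 8-bit weights locally with `save_pretrained` (a minimal sketch, assuming `bitsandbytes>0.37.2` is installed; the output directory name is arbitrary):
```python
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained("bigscience/bloom-560m", device_map="auto", load_in_8bit=True)

# Writes the int8 weights together with the quantization config to a local folder.
model.save_pretrained("bloom-560m-8bit")
```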
### Load a quantized model from the 🤗 Hub
You can load a quantized model from the Hub by using the `from_pretrained` method. Make sure that the pushed weights are quantized by checking that the `quantization_config` attribute is present in the model configuration object.
```python
from transformers import AutoModelForCausalLM, AutoTokenizer
model = AutoModelForCausalLM.from_pretrained("{your_username}/bloom-560m-8bit")
```
Note that in this case, you don't need to specify the arguments `load_in_8bit=True` and `device_map="auto"`, but you need to make sure that `bitsandbytes` and `accelerate` are installed.
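To double-check that the checkpoint was indeed loaded in 8-bit, you can inspect the quantization config and a weight module. A hedged sketch (the repository name is a placeholder and the module path below is specific to BLOOM-style models):
```python
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained("{your_username}/bloom-560m-8bit")

# The quantization configuration travels with the checkpoint.
print(model.config.quantization_config)

# 8-bit linear weights are `bnb.nn.Int8Params` and carry their scaling
# statistics in the `SCB` attribute.
weight = model.transformer.h[0].mlp.dense_4h_to_h.weight
print(type(weight), hasattr(weight, "SCB"))
```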
### Advanced usecases
This section is intended for advanced users who want to explore what is possible beyond loading and running 8-bit models.
......
@@ -801,6 +801,13 @@ class PretrainedConfig(PushToHubMixin):
# Transformers version when serializing the model
output["transformers_version"] = __version__
if hasattr(self, "quantization_config"):
output["quantization_config"] = (
self.quantization_config.to_dict()
if not isinstance(self.quantization_config, dict)
else self.quantization_config
)
self.dict_torch_dtype_to_str(output)
return output
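For context, a hedged sketch of what this change means for a saved checkpoint: the serialized `config.json` gains a `quantization_config` entry carrying the `BitsAndBytesConfig` fields (the path below is hypothetical, and the values shown are the defaults):
```python
from transformers import AutoConfig

# Hypothetical local directory produced by `save_pretrained` on an 8-bit model.
config = AutoConfig.from_pretrained("path/to/bloom-560m-8bit")

# Roughly what the block above adds to the serialized dictionary:
# {
#     "load_in_8bit": True,
#     "llm_int8_threshold": 6.0,
#     "llm_int8_skip_modules": None,
#     "llm_int8_enable_fp32_cpu_offload": False,
# }
print(config.quantization_config)
```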
......
@@ -697,7 +697,15 @@ def _load_state_dict_into_meta_model(
# For backward compatibility with older versions of `accelerate`
set_module_tensor_to_device(model, param_name, param_device, **set_module_kwargs)
else:
if param.dtype == torch.int8 and param_name.replace("weight", "SCB") in state_dict.keys():
fp16_statistics = state_dict[param_name.replace("weight", "SCB")]
else:
fp16_statistics = None
if "SCB" not in param_name:
set_module_8bit_tensor_to_device(
model, param_name, param_device, value=param, fp16_statistics=fp16_statistics
)
return error_msgs, offload_index, state_dict_index
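A brief note on the `SCB` convention used here: for every serialized int8 weight, the checkpoint stores a companion tensor of fp16 quantization statistics under the same key with `weight` replaced by `SCB`. A minimal, hypothetical sketch of that pairing (the key names and shapes are made up for illustration):
```python
import torch

# Hypothetical state-dict entries for one quantized linear layer.
state_dict = {
    "transformer.h.0.mlp.dense_4h_to_h.weight": torch.zeros(4, 4, dtype=torch.int8),
    "transformer.h.0.mlp.dense_4h_to_h.SCB": torch.zeros(4, dtype=torch.float16),
}

param_name = "transformer.h.0.mlp.dense_4h_to_h.weight"
scb_key = param_name.replace("weight", "SCB")

# The loader above looks up the matching statistics and forwards them as
# `fp16_statistics` so they can be re-attached to the weight on its device.
fp16_statistics = state_dict.get(scb_key)
assert fp16_statistics is not None
```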
@@ -1700,10 +1708,10 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMixin):
Additional key word arguments passed along to the [`~utils.PushToHubMixin.push_to_hub`] method.
"""
# Checks if the model has been loaded in 8-bit
if getattr(self, "is_loaded_in_8bit", False) and getattr(self, "is_8bit_serializable", False):
warnings.warn(
"You are calling `save_pretrained` to a 8-bit converted model you may likely encounter unexepected"
" behaviors. If you want to save 8-bit models, make sure to have `bitsandbytes>0.37.2` installed.",
UserWarning,
)
@@ -2165,6 +2173,11 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMixin):
variant = kwargs.pop("variant", None)
use_safetensors = kwargs.pop("use_safetensors", None if is_safetensors_available() else False)
if is_bitsandbytes_available():
is_8bit_serializable = version.parse(importlib_metadata.version("bitsandbytes")) > version.parse("0.37.2")
else:
is_8bit_serializable = False
if trust_remote_code is True:
logger.warning(
"The argument `trust_remote_code` is to be used with Auto classes. It has no effect here and is"
@@ -2207,6 +2220,12 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMixin):
"`quantization_config` argument at the same time."
)
# in the case a user loads an 8bit model from the Hub and assigns a new quantization_config
if device_map is None:
device_map = "auto"
if low_cpu_mem_usage is None:
low_cpu_mem_usage = True
if load_in_8bit:
if not (is_accelerate_available() and is_bitsandbytes_available()):
raise ImportError(
@@ -2265,6 +2284,43 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMixin):
else:
model_kwargs = kwargs
if is_8bit_serializable and quantization_config is not None and load_in_8bit:
if hasattr(config, "quantization_config"):
logger.warning(
"You passed `quantization_config` to `from_pretrained` but the model you're loading already has a"
" `quantization_config` attribute. The `quantization_config` attribute will be overwritten with the"
" one you passed to `from_pretrained`."
)
config.quantization_config = quantization_config
elif is_8bit_serializable and not load_in_8bit and hasattr(config, "quantization_config"):
quantization_config = config.quantization_config
if isinstance(quantization_config, dict):
quantization_config = BitsAndBytesConfig.from_dict(quantization_config, return_unused_kwargs=False)
elif isinstance(quantization_config, BitsAndBytesConfig):
pass
else:
raise ValueError(
f"Invalid type for `quantization_config`: {type(quantization_config)}. Should be a `dict` or a"
" `BitsAndBytesConfig` instance."
)
load_in_8bit = quantization_config.load_in_8bit
if load_in_8bit:
torch_dtype = torch.float16
if device_map is None:
device_map = "auto"
if low_cpu_mem_usage is None:
low_cpu_mem_usage = True
elif not is_8bit_serializable and not load_in_8bit and hasattr(config, "quantization_config"):
logger.warning(
"Detected the presence of a `quantization_config` attribute in the model's configuration but you don't have the correct"
" `bitsandbytes` version to support int8 serialization. Please install the latest version of `bitsandbytes` with "
" `pip install --upgrade bitsandbytes`."
)
if commit_hash is None:
commit_hash = getattr(config, "_commit_hash", None)
@@ -2621,6 +2677,9 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMixin):
importlib_metadata.version("bitsandbytes")
) >= version.parse("0.37.0")
model.config.quantization_config = quantization_config
model.is_8bit_serializable = is_8bit_serializable
if isinstance(device_map, str):
special_dtypes = {}
if load_in_8bit:
@@ -3113,6 +3172,10 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMixin):
)
raise RuntimeError(f"Error(s) in loading state_dict for {model.__class__.__name__}:\n\t{error_msg}")
if load_in_8bit:
unexpected_keys = [elem for elem in unexpected_keys if "SCB" not in elem]
missing_keys = [elem for elem in missing_keys if "SCB" not in elem]
if len(unexpected_keys) > 0:
logger.warning(
f"Some weights of the model checkpoint at {pretrained_model_name_or_path} were not used when"
......
from copy import deepcopy
from packaging import version
from .import_utils import importlib_metadata, is_accelerate_available, is_bitsandbytes_available
if is_bitsandbytes_available():
@@ -13,7 +15,7 @@ if is_accelerate_available():
from accelerate.utils import find_tied_parameters
def set_module_8bit_tensor_to_device(module, tensor_name, device, value=None, fp16_statistics=None):
""" """
A helper function to set a given tensor (parameter of buffer) of a module on a specific device (note that doing A helper function to set a given tensor (parameter of buffer) of a module on a specific device (note that doing
`param.to(device)` creates a new tensor not linked to the parameter, which is why we need this function). The `param.to(device)` creates a new tensor not linked to the parameter, which is why we need this function). The
...@@ -29,6 +31,8 @@ def set_module_8bit_tensor_to_device(module, tensor_name, device, value=None): ...@@ -29,6 +31,8 @@ def set_module_8bit_tensor_to_device(module, tensor_name, device, value=None):
The device on which to set the tensor. The device on which to set the tensor.
value (`torch.Tensor`, *optional*): value (`torch.Tensor`, *optional*):
The value of the tensor (useful when going from the meta device to any other device). The value of the tensor (useful when going from the meta device to any other device).
fp16_statistics (`torch.HalfTensor`, *optional*):
The list of fp16 statistics to set on the module, used for serialization.
""" """
# Recurse if needed # Recurse if needed
if "." in tensor_name: if "." in tensor_name:
...@@ -61,14 +65,21 @@ def set_module_8bit_tensor_to_device(module, tensor_name, device, value=None): ...@@ -61,14 +65,21 @@ def set_module_8bit_tensor_to_device(module, tensor_name, device, value=None):
elif isinstance(value, torch.Tensor): elif isinstance(value, torch.Tensor):
new_value = value.to("cpu") new_value = value.to("cpu")
if value.dtype == torch.int8: if value.dtype == torch.int8:
is_8bit_serializable = version.parse(importlib_metadata.version("bitsandbytes")) > version.parse(
"0.37.2"
)
if not is_8bit_serializable:
raise ValueError(
"Detected int8 weights but the version of bitsandbytes is not compatible with int8 serialization. "
"Make sure to download the latest `bitsandbytes` version. `pip install --upgrade bitsandbytes`."
)
else:
new_value = torch.tensor(value, device="cpu")
new_value = bnb.nn.Int8Params(new_value, requires_grad=False, has_fp16_weights=has_fp16_weights).to(device)
module._parameters[tensor_name] = new_value
if fp16_statistics is not None:
setattr(module.weight, "SCB", fp16_statistics.to(device))
else:
if value is None:
new_value = old_value.to(device)
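To make the `fp16_statistics` handling concrete, here is a hedged, torch-only illustration of the attach step (a plain parameter stands in for `bnb.nn.Int8Params`; real loading always goes through the helper above):
```python
import torch
from torch import nn

# Hypothetical int8 weight and its fp16 statistics; shapes are made up for illustration.
int8_weight = nn.Parameter(torch.zeros(4, 4, dtype=torch.int8), requires_grad=False)
fp16_statistics = torch.zeros(4, dtype=torch.float16)

# As in the helper: the statistics are attached to the weight under the `SCB`
# attribute, which is where bitsandbytes looks them up at inference time.
setattr(int8_weight, "SCB", fp16_statistics)
assert hasattr(int8_weight, "SCB")
```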
......
@@ -14,7 +14,16 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import copy
import json
import os
from dataclasses import dataclass
from typing import Any, Dict, Union
from ..utils import logging
logger = logging.get_logger(__name__)
@dataclass
@@ -49,6 +58,8 @@ class BitsAndBytesConfig:
your model in different parts and run some parts in int8 on GPU and some parts in fp32 on CPU, you can use
this flag. This is useful for offloading large models such as `google/flan-t5-xxl`. Note that the int8
operations will not be run on CPU.
kwargs (`Dict[str, Any]`, *optional*):
Additional parameters from which to initialize the configuration object.
""" """
def __init__( def __init__(
@@ -57,6 +68,7 @@ class BitsAndBytesConfig:
llm_int8_threshold=6.0,
llm_int8_skip_modules=None,
llm_int8_enable_fp32_cpu_offload=False,
**kwargs,
):
self.load_in_8bit = load_in_8bit
self.llm_int8_threshold = llm_int8_threshold
@@ -81,17 +93,19 @@ class BitsAndBytesConfig:
@classmethod
def from_dict(cls, config_dict, return_unused_kwargs, **kwargs):
"""
Instantiates a [`BitsAndBytesConfig`] from a Python dictionary of parameters.
Args:
config_dict (`Dict[str, Any]`):
Dictionary that will be used to instantiate the configuration object.
return_unused_kwargs (`bool`):
Whether or not to return a list of unused keyword arguments. Used for `from_pretrained` method in
`PreTrainedModel`.
kwargs (`Dict[str, Any]`):
Additional parameters from which to initialize the configuration object.
Returns:
[`BitsAndBytesConfig`]: The configuration object instantiated from those parameters.
""" """
config = cls(**config_dict) config = cls(**config_dict)
@@ -107,3 +121,28 @@
return config, kwargs
else:
return config
def to_json_file(self, json_file_path: Union[str, os.PathLike]):
"""
Save this instance to a JSON file.
Args:
json_file_path (`str` or `os.PathLike`):
Path to the JSON file in which this configuration instance's parameters will be saved.
"""
with open(json_file_path, "w", encoding="utf-8") as writer:
config_dict = self.to_dict()
json_string = json.dumps(config_dict, indent=2, sort_keys=True) + "\n"
writer.write(json_string)
def to_dict(self) -> Dict[str, Any]:
"""
Serializes this instance to a Python dictionary. Returns:
`Dict[str, Any]`: Dictionary of all the attributes that make up this configuration instance.
"""
output = copy.deepcopy(self.__dict__)
return output
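As a usage note, here is a hedged sketch of how these new helpers can be exercised directly (the output file name is arbitrary):
```python
from transformers import BitsAndBytesConfig

quantization_config = BitsAndBytesConfig(load_in_8bit=True, llm_int8_threshold=6.0)

# Round-trip through a plain dict, as `PretrainedConfig.to_dict` does when
# embedding the quantization config into the model's config.json.
config_dict = quantization_config.to_dict()
restored = BitsAndBytesConfig.from_dict(config_dict, return_unused_kwargs=False)

# Or write the configuration straight to disk.
quantization_config.to_json_file("quantization_config.json")
```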
@@ -19,6 +19,7 @@ import unittest
from packaging import version
from transformers import (
AutoConfig,
AutoModel,
AutoModelForCausalLM,
AutoModelForSeq2SeqLM,
@@ -150,6 +151,13 @@ class MixedInt8Test(BaseMixedInt8Test):
self.assertEqual(self.tokenizer.decode(output_sequences[0], skip_special_tokens=True), self.EXPECTED_OUTPUT)
def test_warns_save_pretrained(self):
r"""
Test whether trying to save a model after converting it in 8-bit will throw a warning.
"""
with self.assertWarns(UserWarning), tempfile.TemporaryDirectory() as tmpdirname:
self.model_8bit.save_pretrained(tmpdirname)
def test_raise_if_config_and_load_in_8bit(self):
r"""
Test that loading the model with the config and `load_in_8bit` raises an error
@@ -165,13 +173,6 @@ class MixedInt8Test(BaseMixedInt8Test):
llm_int8_enable_fp32_cpu_offload=True,
)
def test_warns_save_pretrained(self):
r"""
Test whether trying to save a model after converting it in 8-bit will throw a warning.
"""
with self.assertWarns(UserWarning), tempfile.TemporaryDirectory() as tmpdirname:
self.model_8bit.save_pretrained(tmpdirname)
def test_device_and_dtype_assignment(self):
r"""
Test whether trying to cast (or assigning a device to) a model after converting it in 8-bit will throw an error.
@@ -219,6 +220,77 @@ class MixedInt8Test(BaseMixedInt8Test):
model = AutoModelForSeq2SeqLM.from_pretrained("t5-small", load_in_8bit=True, device_map="auto")
self.assertTrue(model.decoder.block[0].layer[2].DenseReluDense.wo.weight.dtype == torch.float32)
def test_int8_serialization(self):
r"""
Test whether it is possible to serialize a model in 8-bit.
"""
from bitsandbytes.nn import Int8Params
with tempfile.TemporaryDirectory() as tmpdirname:
self.model_8bit.save_pretrained(tmpdirname)
# check that the file `quantization_config` is present
config = AutoConfig.from_pretrained(tmpdirname)
self.assertTrue(hasattr(config, "quantization_config"))
model_from_saved = AutoModelForCausalLM.from_pretrained(tmpdirname, load_in_8bit=True, device_map="auto")
self.assertTrue(model_from_saved.transformer.h[0].mlp.dense_4h_to_h.weight.__class__ == Int8Params)
self.assertTrue(hasattr(model_from_saved.transformer.h[0].mlp.dense_4h_to_h.weight, "SCB"))
# generate
encoded_input = self.tokenizer(self.input_text, return_tensors="pt")
output_sequences = model_from_saved.generate(input_ids=encoded_input["input_ids"].to(0), max_new_tokens=10)
self.assertEqual(
self.tokenizer.decode(output_sequences[0], skip_special_tokens=True), self.EXPECTED_OUTPUT
)
def test_int8_serialization_sharded(self):
r"""
Test whether it is possible to serialize a model in 8-bit - sharded version.
"""
from bitsandbytes.nn import Int8Params
with tempfile.TemporaryDirectory() as tmpdirname:
self.model_8bit.save_pretrained(tmpdirname, max_shard_size="200MB")
# check that the file `quantization_config` is present
config = AutoConfig.from_pretrained(tmpdirname)
self.assertTrue(hasattr(config, "quantization_config"))
model_from_saved = AutoModelForCausalLM.from_pretrained(tmpdirname)
self.assertTrue(model_from_saved.transformer.h[0].mlp.dense_4h_to_h.weight.__class__ == Int8Params)
self.assertTrue(hasattr(model_from_saved.transformer.h[0].mlp.dense_4h_to_h.weight, "SCB"))
# generate
encoded_input = self.tokenizer(self.input_text, return_tensors="pt")
output_sequences = model_from_saved.generate(input_ids=encoded_input["input_ids"].to(0), max_new_tokens=10)
self.assertEqual(
self.tokenizer.decode(output_sequences[0], skip_special_tokens=True), self.EXPECTED_OUTPUT
)
def test_int8_from_pretrained(self):
r"""
Test whether loading a 8bit model from the Hub works as expected
"""
from bitsandbytes.nn import Int8Params
model_id = "ybelkada/bloom-1b7-8bit"
model = AutoModelForCausalLM.from_pretrained(model_id)
self.assertTrue(model.transformer.h[0].mlp.dense_4h_to_h.weight.__class__ == Int8Params)
self.assertTrue(hasattr(model.transformer.h[0].mlp.dense_4h_to_h.weight, "SCB"))
# generate
encoded_input = self.tokenizer(self.input_text, return_tensors="pt")
output_sequences = model.generate(input_ids=encoded_input["input_ids"].to(0), max_new_tokens=10)
self.assertEqual(self.tokenizer.decode(output_sequences[0], skip_special_tokens=True), self.EXPECTED_OUTPUT)
@require_bitsandbytes
@require_accelerate
@@ -289,6 +361,38 @@ class MixedInt8T5Test(unittest.TestCase):
encoded_input = self.tokenizer(self.input_text, return_tensors="pt").to(0)
_ = model.generate(**encoded_input)
def test_inference_with_keep_in_fp32_serialized(self):
r"""
Test whether it is possible to mix both `int8` and `fp32` weights when using `keep_in_fp32_modules` correctly on
a serialized model.
`flan-t5-small` uses `T5DenseGatedActDense` whereas `t5-small` uses `T5DenseReluDense`. We need to test
both cases.
"""
import bitsandbytes as bnb
from transformers import T5ForConditionalGeneration
# test with `t5-small`
model = T5ForConditionalGeneration.from_pretrained(self.model_name, load_in_8bit=True, device_map="auto")
with tempfile.TemporaryDirectory() as tmp_dir:
model.save_pretrained(tmp_dir)
model = T5ForConditionalGeneration.from_pretrained(tmp_dir)
# there was a bug with decoders - this test checks that it is fixed
self.assertTrue(isinstance(model.decoder.block[0].layer[0].SelfAttention.q, bnb.nn.Linear8bitLt))
encoded_input = self.tokenizer(self.input_text, return_tensors="pt").to(0)
_ = model.generate(**encoded_input)
# test with `flan-t5-small`
model = T5ForConditionalGeneration.from_pretrained(
self.dense_act_model_name, load_in_8bit=True, device_map="auto"
)
encoded_input = self.tokenizer(self.input_text, return_tensors="pt").to(0)
_ = model.generate(**encoded_input)
class MixedInt8ModelClassesTest(BaseMixedInt8Test):
def setUp(self):
......