Unverified Commit 3668ec17 authored by Younes Belkada, committed by GitHub

[`bnb`] Introducing `BitsAndBytesConfig` (#21579)



* v1 `BitsandbytesConfig`

- add v1
- add tests
- more user-friendly API
- add docs

* change to `BitsAndBytesConfig`

* replace logic

* changes

* make fixup

* quality

* make fixup

* fix doc

* fix test

* update toctree

* fix slow test

* add tips

* add warning

* change title

* oops

* Update docs/source/en/main_classes/quantization.mdx
Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

* Update src/transformers/utils/bitsandbytes.py
Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

* remove unused file

* adapt suggestion

- add also tests
- change logic

* update docs

* adapt suggestions

---------
Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
parent f16d29b3
......@@ -199,6 +199,8 @@
title: Pipelines
- local: main_classes/processors
title: Processors
- local: main_classes/quantization
title: Quantization
- local: main_classes/tokenizer
title: Tokenizer
- local: main_classes/trainer
......
<!--Copyright 2023 The HuggingFace Team. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.
-->
# Quantize 🤗 Transformers models
## `bitsandbytes` Integration
🤗 Transformers is closely integrated with the most commonly used modules of `bitsandbytes`. You can load your model in 8-bit precision with a few lines of code.
This is supported by most GPU hardware since the `0.37.0` release of `bitsandbytes`.
Learn more about the quantization method in the [LLM.int8()](https://arxiv.org/abs/2208.07339) paper, or the [blog post](https://huggingface.co/blog/hf-bitsandbytes-integration) about the collaboration.
Here is what you can do with the `bitsandbytes` integration.
### Load a large model in 8bit
You can roughly halve the memory requirements of a model by passing the `load_in_8bit=True` argument when calling the `.from_pretrained` method:
```python
# pip install transformers accelerate bitsandbytes
from transformers import AutoModelForCausalLM, AutoTokenizer
model_id = "bigscience/bloom-1b7"
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id, device_map="auto", load_in_8bit=True)
```
Then, use your model as you would usually use a [`PreTrainedModel`].
You can check the memory footprint of your model with the `get_memory_footprint` method:
```python
print(model.get_memory_footprint())
```
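Below is a minimal end-to-end sketch (not part of the original example; the prompt and generation settings are arbitrary) that runs generation on the 8-bit model and compares its footprint with a plain `float16` copy, assuming you have enough memory to hold both:
```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "bigscience/bloom-1b7"
tokenizer = AutoTokenizer.from_pretrained(model_id)
model_8bit = AutoModelForCausalLM.from_pretrained(model_id, device_map="auto", load_in_8bit=True)

# Use the 8-bit model exactly as you would any `PreTrainedModel`
inputs = tokenizer("Hello, my name is", return_tensors="pt").to(0)
outputs = model_8bit.generate(**inputs, max_new_tokens=20)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))

# Compare the footprint against a non-quantized float16 copy
model_fp16 = AutoModelForCausalLM.from_pretrained(model_id, device_map="auto", torch_dtype=torch.float16)
print(f"8-bit footprint: {model_8bit.get_memory_footprint() / 1e9:.2f} GB")
print(f"fp16 footprint:  {model_fp16.get_memory_footprint() / 1e9:.2f} GB")
```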
With this integration, you can load large models on smaller devices and run them without any issue.
<Tip warning={true}>
Note that once a model has been loaded in 8-bit, it is currently not possible to push the quantized weights to the Hub. Also note that training the 8-bit weights themselves is not supported yet. However, you can use 8-bit models to train extra parameters; this is covered in the next section.
</Tip>
### Advanced use cases
This section is intended for advanced users who want to explore what is possible beyond loading and running 8-bit models.
#### Offload between `cpu` and `gpu`
One of the advanced use cases is loading a model and dispatching its weights between `CPU` and `GPU`. Note that the weights dispatched to the CPU **will not** be converted to 8-bit but kept in `float32`. This feature is intended for users who want to fit a very large model and split it between the GPU and CPU.
First, create a `BitsAndBytesConfig` from `transformers` and set the attribute `llm_int8_enable_fp32_cpu_offload` to `True`:
```python
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
quantization_config = BitsAndBytesConfig(llm_int8_enable_fp32_cpu_offload=True)
```
Let's say you want to load the `bigscience/bloom-1b7` model and you have just enough GPU RAM to fit the entire model except the `lm_head`. Write a custom `device_map` as follows:
```python
device_map = {
"transformer.word_embeddings": 0,
"transformer.word_embeddings_layernorm": 0,
"lm_head": "cpu",
"transformer.h": 0,
"transformer.ln_f": 0,
}
```
And load your model as follows:
```python
model_8bit = AutoModelForCausalLM.from_pretrained(
"bigscience/bloom-1b7",
device_map=device_map,
quantization_config=quantization_config,
)
```
And that's it! Enjoy your model!
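If you prefer not to write the `device_map` by hand, one possible alternative (a sketch; the memory budget below is hypothetical and should be adapted to your hardware) is to pass `device_map="auto"` together with a `max_memory` budget and let Accelerate place the modules, with anything that does not fit on the GPU offloaded to the CPU in `float32`:
```python
from transformers import AutoModelForCausalLM, BitsAndBytesConfig

quantization_config = BitsAndBytesConfig(llm_int8_enable_fp32_cpu_offload=True)

# Hypothetical budget: cap GPU 0 at 2GiB and allow up to 10GiB of CPU RAM
model_8bit = AutoModelForCausalLM.from_pretrained(
    "bigscience/bloom-1b7",
    device_map="auto",
    max_memory={0: "2GiB", "cpu": "10GiB"},
    quantization_config=quantization_config,
)
# Inspect where each module ended up
print(model_8bit.hf_device_map)
```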
#### Play with `llm_int8_threshold`
You can play with the `llm_int8_threshold` argument to change the threshold of the outliers. An "outlier" is a hidden state value that is greater than a certain threshold.
This corresponds to the outlier threshold for outlier detection as described in `LLM.int8()` paper. Any hidden states value that is above this threshold will be considered an outlier and the operation on those values will be done in fp16. Values are usually normally distributed, that is, most values are in the range [-3.5, 3.5], but there are some exceptional systematic outliers that are very differently distributed for large models. These outliers are often in the interval [-60, -6] or [6, 60]. Int8 quantization works well for values of magnitude ~5, but beyond that, there is a significant performance penalty. A good default threshold is 6, but a lower threshold might be needed for more unstable models (small models, fine-tuning).
This argument can impact the inference speed of the model. We suggest playing with this parameter to find the value that works best for your use case.
```python
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
model_id = "bigscience/bloom-1b7"
quantization_config = BitsAndBytesConfig(
llm_int8_threshold=10,
)
model_8bit = AutoModelForCausalLM.from_pretrained(
model_id,
device_map="auto",
quantization_config=quantization_config,
)
tokenizer = AutoTokenizer.from_pretrained(model_id)
```
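If you want a quick feel for the impact on inference speed, a rough timing sketch like the one below can help (loading the model twice in a row with two different thresholds; the exact numbers depend on your hardware and prompt):
```python
import time

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig

model_id = "bigscience/bloom-1b7"
tokenizer = AutoTokenizer.from_pretrained(model_id)
inputs = tokenizer("Hello my name is", return_tensors="pt").to(0)

for threshold in (6.0, 10.0):
    config = BitsAndBytesConfig(llm_int8_threshold=threshold)
    model = AutoModelForCausalLM.from_pretrained(model_id, device_map="auto", quantization_config=config)
    torch.cuda.synchronize()
    start = time.time()
    model.generate(**inputs, max_new_tokens=50)
    torch.cuda.synchronize()
    print(f"llm_int8_threshold={threshold}: {time.time() - start:.2f}s")
    # Free the GPU before loading the next variant
    del model
    torch.cuda.empty_cache()
```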
#### Skip the conversion of some modules
Some models have several modules that must not be converted to 8-bit to ensure stability. For example, the Jukebox model has several `lm_head` modules that should be skipped. You can control this with `llm_int8_skip_modules`:
```python
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
model_id = "bigscience/bloom-1b7"
quantization_config = BitsAndBytesConfig(
llm_int8_skip_modules=["lm_head"],
)
model_8bit = AutoModelForCausalLM.from_pretrained(
model_id,
device_map="auto",
quantization_config=quantization_config,
)
tokenizer = AutoTokenizer.from_pretrained(model_id)
```
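To double-check which modules were actually converted, a small inspection sketch (assuming `bitsandbytes` is installed and `model_8bit` is the model loaded above) is to walk the module tree: modules listed in `llm_int8_skip_modules` should still be plain `torch.nn.Linear`, while the other linear layers become `bnb.nn.Linear8bitLt`:
```python
import bitsandbytes as bnb
import torch.nn as nn

for name, module in model_8bit.named_modules():
    if isinstance(module, bnb.nn.Linear8bitLt):
        print(f"{name}: converted to 8-bit")
    elif isinstance(module, nn.Linear):
        # Skipped modules (e.g. `lm_head`) keep their original dtype
        print(f"{name}: kept in {module.weight.dtype}")
```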
#### Fine-tune a model that has been loaded in 8-bit
With the official support of adapters in the Hugging Face ecosystem, you can fine-tune models that have been loaded in 8-bit.
This enables fine-tuning large models such as `flan-t5-large` or `facebook/opt-6.7b` in a single Google Colab. Please have a look at the [`peft`](https://github.com/huggingface/peft) library for more details; a short sketch is shown below.
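As a rough sketch of what this looks like with `peft` (helper names and recommended settings may differ between `peft` versions, so treat this as an outline rather than the exact recipe), you wrap the frozen 8-bit model with a small set of trainable LoRA weights:
```python
# Sketch only: `prepare_model_for_int8_training`, `LoraConfig` and `get_peft_model`
# come from the `peft` library and their exact names/arguments may evolve.
from peft import LoraConfig, get_peft_model, prepare_model_for_int8_training
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained("facebook/opt-6.7b", load_in_8bit=True, device_map="auto")

# Cast layer norms / the output head to fp32 and enable gradient checkpointing for stability
model = prepare_model_for_int8_training(model)

# Attach trainable low-rank adapters on top of the frozen 8-bit base model
lora_config = LoraConfig(
    r=16,
    lora_alpha=32,
    target_modules=["q_proj", "v_proj"],
    lora_dropout=0.05,
    bias="none",
    task_type="CAUSAL_LM",
)
model = get_peft_model(model, lora_config)
model.print_trainable_parameters()  # only the adapter parameters require gradients
```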
### BitsAndBytesConfig
[[autodoc]] BitsAndBytesConfig
## Quantization with 🤗 `optimum`
Please have a look at the [Optimum documentation](https://huggingface.co/docs/optimum/index) to learn more about the quantization methods supported by `optimum` and see whether they are applicable for your use case.
......@@ -27,6 +27,7 @@ from . import dependency_versions_check
from .utils import (
OptionalDependencyNotAvailable,
_LazyModule,
is_bitsandbytes_available,
is_flax_available,
is_keras_nlp_available,
is_sentencepiece_available,
......@@ -596,6 +597,7 @@ _import_structure = {
"add_end_docstrings",
"add_start_docstrings",
"is_apex_available",
"is_bitsandbytes_available",
"is_datasets_available",
"is_decord_available",
"is_faiss_available",
......@@ -622,6 +624,7 @@ _import_structure = {
"logging",
],
"utils.bitsandbytes": [],
"utils.quantization_config": ["BitsAndBytesConfig"],
}
# sentencepiece-backed objects
......@@ -4114,6 +4117,7 @@ if TYPE_CHECKING:
add_end_docstrings,
add_start_docstrings,
is_apex_available,
is_bitsandbytes_available,
is_datasets_available,
is_decord_available,
is_faiss_available,
......@@ -4140,6 +4144,9 @@ if TYPE_CHECKING:
logging,
)
# bitsandbytes config
from .utils.quantization_config import BitsAndBytesConfig
try:
if not is_sentencepiece_available():
raise OptionalDependencyNotAvailable()
......@@ -6535,6 +6542,7 @@ if TYPE_CHECKING:
FlaxXLMRobertaPreTrainedModel,
)
else:
import sys
......
......@@ -74,6 +74,7 @@ from .utils import (
replace_return_docstrings,
)
from .utils.import_utils import importlib_metadata
from .utils.quantization_config import BitsAndBytesConfig
from .utils.versions import require_version_core
......@@ -1992,19 +1993,9 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
https://test.pypi.org/simple/ bitsandbytes-cudaXXX` where XXX is your CUDA version (e.g. 11.6 = 116).
Make also sure that you have enough GPU RAM to store half of the model size since the 8bit modules are
not compiled and adapted for CPUs.
load_in_8bit_threshold (`float`, *optional*, defaults to 6):
Works together with `load_in_8bit`. This corresponds to the outlier threshold for outlier detection as
described in `GPT3.int8() : 8-bit Matrix Multiplication for Transformers at Scale` paper. Any hidden
states value that is above this threshold will be considered an outlier and the operation on those
values will be done in fp16. Values are usually normally distributed, that is, most values are in the
range [-3.5, 3.5], but there are some exceptional systematic outliers that are very differently
distributed for large models. These outliers are often in the interval [-60, -6] or [6, 60]. Int8
quantization works well for values of magnitude ~5, but beyond that, there is a significant performance
penalty. A good default threshold is 6, but a lower threshold might be needed for more unstable models
(small models, fine-tuning).
load_in_8bit_skip_modules (`List[str]`, *optional*):
An explicit list of the modules that we do not want to convert in 8-bit. This is useful for models such
as Jukebox that has several heads in different places and not necessarily at the last position.
quantization_config (`Dict`, *optional*):
A dictionary of configuration parameters for the `bitsandbytes` library and loading the model using
advanced features such as offloading in fp32 on CPU or on disk.
subfolder (`str`, *optional*, defaults to `""`):
In case the relevant files are located inside a subfolder of the model repo on huggingface.co, you can
specify the folder name here.
......@@ -2093,8 +2084,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
offload_folder = kwargs.pop("offload_folder", None)
offload_state_dict = kwargs.pop("offload_state_dict", False)
load_in_8bit = kwargs.pop("load_in_8bit", False)
load_in_8bit_threshold = kwargs.pop("load_in_8bit_threshold", 6.0)
load_in_8bit_skip_modules = kwargs.pop("load_in_8bit_skip_modules", None)
quantization_config = kwargs.pop("quantization_config", None)
subfolder = kwargs.pop("subfolder", "")
commit_hash = kwargs.pop("_commit_hash", None)
variant = kwargs.pop("variant", None)
......@@ -2126,6 +2116,23 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
"Using `low_cpu_mem_usage=True` or a `device_map` requires Accelerate: `pip install accelerate`"
)
if quantization_config is None:
quantization_config, kwargs = BitsAndBytesConfig.from_dict(
config_dict={"load_in_8bit": load_in_8bit}, return_unused_kwargs=True, **kwargs
)
elif quantization_config is not None:
load_in_8bit = quantization_config.load_in_8bit
quantization_config_kwargs = {
k: v for k, v in kwargs.items() if k in inspect.signature(BitsAndBytesConfig).parameters
}
if len(quantization_config_kwargs) > 0:
raise ValueError(
"You can't pass `load_in_8bit` or any other `BitsAndBytesConfig` argument as a kwarg when passing "
"`quantization_config` argument at the same time."
)
if load_in_8bit:
if not (is_accelerate_available() and is_bitsandbytes_available()):
raise ImportError(
......@@ -2497,6 +2504,10 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
if load_in_8bit:
from .utils.bitsandbytes import get_keys_to_not_convert, replace_8bit_linear
load_in_8bit_skip_modules = quantization_config.llm_int8_skip_modules
load_in_8bit_threshold = quantization_config.llm_int8_threshold
load_in_8bit_fp32_cpu_offload = quantization_config.llm_int8_enable_fp32_cpu_offload
logger.info("Detected 8-bit loading: activating 8-bit loading for this model")
# We keep some modules such as the lm_head in their original dtype for numerical stability reasons
......@@ -2510,6 +2521,19 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
modules_to_not_convert.extend(keep_in_fp32_modules)
# Extend the modules to not convert to keys that are supposed to be offloaded to `cpu` or `disk`
if isinstance(device_map, dict) and len(device_map.keys()) > 1:
keys_on_cpu = [key for key, value in device_map.items() if value in ["disk", "cpu"]]
if len(keys_on_cpu) > 0 and not load_in_8bit_fp32_cpu_offload:
raise ValueError(
"If you want to offload some keys to `cpu` or `disk`, you need to set "
"`llm_int8_enable_fp32_cpu_offload=True`. Note that these modules will not be "
"converted to 8-bit but kept in 32-bit."
)
modules_to_not_convert.extend(keys_on_cpu)
model = replace_8bit_linear(
model, threshold=load_in_8bit_threshold, modules_to_not_convert=modules_to_not_convert
)
......
......@@ -84,7 +84,7 @@ def set_module_8bit_tensor_to_device(module, tensor_name, device, value=None):
module._parameters[tensor_name] = new_value
def replace_8bit_linear(model, threshold=6.0, modules_to_not_convert="lm_head"):
def replace_8bit_linear(model, threshold=6.0, modules_to_not_convert="lm_head", current_key_name=None):
"""
A helper function to replace all `torch.nn.Linear` modules by `bnb.nn.Linear8bit` modules from the `bitsandbytes`
library. This will enable running your models using mixed int8 precision as described by the paper `GPT3.int8():
......@@ -108,12 +108,22 @@ def replace_8bit_linear(model, threshold=6.0, modules_to_not_convert="lm_head"):
modules_to_not_convert (`str`, *optional*, defaults to `lm_head`):
Name of the module to not convert in `Linear8bitLt`. In practice we keep the `lm_head` in full precision
for numerical stability reasons.
current_key_name (`List[str]`, *optional*):
An array to track the current key of the recursion. This is used to check whether the current key (or part of
it) is in the list of modules to not convert (for instance, modules that are offloaded to `cpu` or
`disk`).
"""
for name, module in model.named_children():
if current_key_name is None:
current_key_name = []
current_key_name.append(name)
if len(list(module.children())) > 0:
replace_8bit_linear(module, threshold, modules_to_not_convert)
replace_8bit_linear(module, threshold, modules_to_not_convert, current_key_name)
if isinstance(module, nn.Linear) and name not in modules_to_not_convert:
# Check if the current key is not in the `modules_to_not_convert`
if not any(key in ".".join(current_key_name) for key in modules_to_not_convert):
with init_empty_weights():
model._modules[name] = bnb.nn.Linear8bitLt(
module.in_features,
......@@ -122,6 +132,8 @@ def replace_8bit_linear(model, threshold=6.0, modules_to_not_convert="lm_head"):
has_fp16_weights=False,
threshold=threshold,
)
# Remove the last key for recursion
current_key_name.pop(-1)
return model
......
#!/usr/bin/env python
# coding=utf-8
# Copyright 2023 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from dataclasses import dataclass
@dataclass
class BitsAndBytesConfig:
"""
This is a wrapper class for all possible attributes and features that you can play with for a model that has been
loaded using `bitsandbytes`.
This replaces `load_in_8bit`, therefore both options are mutually exclusive.
For now, only the arguments relevant to `LLM.int8()` are supported, hence they are all prefixed with
`llm_int8_*`. If more methods are added to `bitsandbytes`, more arguments will be added to this class.
Args:
load_in_8bit (`bool`, *optional*, defaults to `False`):
This flag is used to enable 8-bit quantization with LLM.int8().
llm_int8_threshold (`float`, *optional*, defaults to 6):
This corresponds to the outlier threshold for outlier detection as described in `LLM.int8() : 8-bit Matrix
Multiplication for Transformers at Scale` paper: https://arxiv.org/abs/2208.07339 Any hidden states value
that is above this threshold will be considered an outlier and the operation on those values will be done
in fp16. Values are usually normally distributed, that is, most values are in the range [-3.5, 3.5], but
there are some exceptional systematic outliers that are very differently distributed for large models.
These outliers are often in the interval [-60, -6] or [6, 60]. Int8 quantization works well for values of
magnitude ~5, but beyond that, there is a significant performance penalty. A good default threshold is 6,
but a lower threshold might be needed for more unstable models (small models, fine-tuning).
llm_int8_skip_modules (`List[str]`, *optional*):
An explicit list of the modules that we do not want to convert to 8-bit. This is useful for models such as
Jukebox that have several heads in different places and not necessarily at the last position. For example,
for `CausalLM` models, the last `lm_head` is kept in its original `dtype`.
llm_int8_enable_fp32_cpu_offload (`bool`, *optional*, defaults to `False`):
This flag is used for advanced use cases and users that are aware of this feature. If you want to split
your model in different parts and run some parts in int8 on GPU and some parts in fp32 on CPU, you can use
this flag. This is useful for offloading large models such as `google/flan-t5-xxl`. Note that the int8
operations will not be run on CPU.
"""
def __init__(
self,
load_in_8bit=False,
llm_int8_threshold=6.0,
llm_int8_skip_modules=None,
llm_int8_enable_fp32_cpu_offload=False,
):
self.load_in_8bit = load_in_8bit
self.llm_int8_threshold = llm_int8_threshold
self.llm_int8_skip_modules = llm_int8_skip_modules
self.llm_int8_enable_fp32_cpu_offload = llm_int8_enable_fp32_cpu_offload
self.post_init()
def post_init(self):
r"""
Safety checker that arguments are correct - also replaces some NoneType arguments with their default values.
"""
if not isinstance(self.llm_int8_threshold, float):
raise ValueError("llm_int8_threshold must be a float")
if self.llm_int8_skip_modules is not None and not isinstance(self.llm_int8_skip_modules, list):
raise ValueError("llm_int8_skip_modules must be a list of strings")
if not isinstance(self.llm_int8_enable_fp32_cpu_offload, bool):
raise ValueError("llm_int8_enable_fp32_cpu_offload must be a boolean")
@classmethod
def from_dict(cls, config_dict, return_unused_kwargs, **kwargs):
"""
Instantiates a [`BitsAndBytesConfig`] from a Python dictionary of parameters.
Args:
config_dict (`Dict[str, Any]`):
Dictionary that will be used to instantiate the configuration object. Such a dictionary can be
retrieved from a pretrained checkpoint by leveraging the [`~PretrainedConfig.get_config_dict`] method.
kwargs (`Dict[str, Any]`):
Additional parameters from which to initialize the configuration object.
Returns:
[`BitsAndBytesConfig`]: The configuration object instantiated from those parameters.
"""
config = cls(**config_dict)
to_remove = []
for key, value in kwargs.items():
if hasattr(config, key):
setattr(config, key, value)
to_remove.append(key)
for key in to_remove:
kwargs.pop(key, None)
if return_unused_kwargs:
return config, kwargs
else:
return config
......@@ -24,6 +24,7 @@ from transformers import (
AutoModelForSeq2SeqLM,
AutoModelForSequenceClassification,
AutoTokenizer,
BitsAndBytesConfig,
pipeline,
)
from transformers.testing_utils import (
......@@ -132,6 +133,38 @@ class MixedInt8Test(BaseMixedInt8Test):
self.assertEqual(self.tokenizer.decode(output_sequences[0], skip_special_tokens=True), self.EXPECTED_OUTPUT)
def test_generate_quality_config(self):
r"""
Test that loading the model with the config is equivalent
"""
bnb_config = BitsAndBytesConfig()
model_8bit_from_config = AutoModelForCausalLM.from_pretrained(
self.model_name, quantization_config=bnb_config, device_map="auto"
)
encoded_input = self.tokenizer(self.input_text, return_tensors="pt")
output_sequences = model_8bit_from_config.generate(
input_ids=encoded_input["input_ids"].to(0), max_new_tokens=10
)
self.assertEqual(self.tokenizer.decode(output_sequences[0], skip_special_tokens=True), self.EXPECTED_OUTPUT)
def test_raise_if_config_and_load_in_8bit(self):
r"""
Test that loading the model with the config and `load_in_8bit` raises an error
"""
bnb_config = BitsAndBytesConfig()
with self.assertRaises(ValueError):
_ = AutoModelForCausalLM.from_pretrained(
self.model_name,
quantization_config=bnb_config,
load_in_8bit=True,
device_map="auto",
llm_int8_enable_fp32_cpu_offload=True,
)
def test_warns_save_pretrained(self):
r"""
Test whether trying to save a model after converting it in 8-bit will throw a warning.
......@@ -361,6 +394,151 @@ class MixedInt8TestMultiGpu(BaseMixedInt8Test):
self.assertEqual(self.tokenizer.decode(output_parallel[0], skip_special_tokens=True), self.EXPECTED_OUTPUT)
@require_torch_multi_gpu
class MixedInt8TestCpuGpu(BaseMixedInt8Test):
def setUp(self):
super().setUp()
def check_inference_correctness(self, model):
# Check that inference pass works on the model
encoded_input = self.tokenizer(self.input_text, return_tensors="pt")
# Check the exactness of the results
output_parallel = model.generate(input_ids=encoded_input["input_ids"].to(0), max_new_tokens=10)
# Get the generation
output_text = self.tokenizer.decode(output_parallel[0], skip_special_tokens=True)
self.assertEqual(output_text, self.EXPECTED_OUTPUT)
def test_cpu_gpu_loading_random_device_map(self):
r"""
A test to check if dispatching a model on cpu & gpu works correctly using a random `device_map`.
"""
device_map = {
"transformer.word_embeddings": 0,
"transformer.word_embeddings_layernorm": 0,
"lm_head": 0,
"transformer.h.0": "cpu",
"transformer.h.1": "cpu",
"transformer.h.2": 0,
"transformer.h.3": 0,
"transformer.h.4": 0,
"transformer.h.5": 0,
"transformer.h.6": 0,
"transformer.h.7": 0,
"transformer.h.8": 0,
"transformer.h.9": 1,
"transformer.h.10": 0,
"transformer.h.11": 1,
"transformer.h.12": 0,
"transformer.h.13": 0,
"transformer.h.14": 1,
"transformer.h.15": 0,
"transformer.h.16": 0,
"transformer.h.17": 1,
"transformer.h.18": 1,
"transformer.h.19": 0,
"transformer.h.20": 1,
"transformer.h.21": 1,
"transformer.h.22": 0,
"transformer.h.23": 0,
"transformer.ln_f": 1,
}
bnb_config = BitsAndBytesConfig(llm_int8_enable_fp32_cpu_offload=True)
model_8bit = AutoModelForCausalLM.from_pretrained(
self.model_name,
device_map=device_map,
quantization_config=bnb_config,
)
# Check that the model has been correctly set on device 0, 1, and `cpu`.
self.assertEqual(set(model_8bit.hf_device_map.values()), {0, 1, "cpu"})
self.check_inference_correctness(model_8bit)
def test_cpu_gpu_loading_custom_device_map(self):
r"""
A test to check if dispatching a model on cpu & gpu works correctly using a custom `device_map`.
This time the device map is more organized than the test above and uses the abstraction
`transformer.h` to encapsulate all the decoder layers.
"""
device_map = {
"transformer.word_embeddings": "cpu",
"transformer.word_embeddings_layernorm": "cpu",
"lm_head": "cpu",
"transformer.h": 0,
"transformer.ln_f": 1,
}
bnb_config = BitsAndBytesConfig(llm_int8_enable_fp32_cpu_offload=True)
# Load model
model_8bit = AutoModelForCausalLM.from_pretrained(
self.model_name,
device_map=device_map,
quantization_config=bnb_config,
)
# Check that the model has been correctly set on device 0, 1, and `cpu`.
self.assertEqual(set(model_8bit.hf_device_map.values()), {0, 1, "cpu"})
self.check_inference_correctness(model_8bit)
def test_cpu_gpu_disk_loading_custom_device_map(self):
r"""
A test to check if dispatching a model on cpu & gpu works correctly using a custom `device_map`.
This time we also add `disk` on the device_map.
"""
device_map = {
"transformer.word_embeddings": 0,
"transformer.word_embeddings_layernorm": "cpu",
"lm_head": 0,
"transformer.h": 1,
"transformer.ln_f": "disk",
}
bnb_config = BitsAndBytesConfig(llm_int8_enable_fp32_cpu_offload=True)
with tempfile.TemporaryDirectory() as tmpdirname:
# Load model
model_8bit = AutoModelForCausalLM.from_pretrained(
self.model_name,
device_map=device_map,
quantization_config=bnb_config,
offload_folder=tmpdirname,
)
# Check that the model has been correctly set on device 0, 1, `cpu` and `disk`.
self.assertEqual(set(model_8bit.hf_device_map.values()), {0, 1, "cpu", "disk"})
self.check_inference_correctness(model_8bit)
def test_cpu_gpu_disk_loading_custom_device_map_kwargs(self):
r"""
A test to check if dispatching a model on cpu & gpu works correctly using a custom `device_map`.
This time we also add `disk` on the device_map - using the kwargs directly instead of the quantization config
"""
device_map = {
"transformer.word_embeddings": 0,
"transformer.word_embeddings_layernorm": "cpu",
"lm_head": 0,
"transformer.h": 1,
"transformer.ln_f": "disk",
}
with tempfile.TemporaryDirectory() as tmpdirname:
# Load model
model_8bit = AutoModelForCausalLM.from_pretrained(
self.model_name,
device_map=device_map,
llm_int8_enable_fp32_cpu_offload=True,
offload_folder=tmpdirname,
)
# Check that the model has been correctly set on device 0, 1, `cpu` and `disk`.
self.assertEqual(set(model_8bit.hf_device_map.values()), {0, 1, "cpu", "disk"})
self.check_inference_correctness(model_8bit)
class MixedInt8TestTraining(BaseMixedInt8Test):
def setUp(self):
self.model_name = "facebook/opt-350m"
......