"docs/vscode:/vscode.git/clone" did not exist on "e0921c6b53310a47b10f01633809b2b9f785a465"
Unverified commit e42869b0, authored by Younes Belkada, committed by GitHub

[`bnb`] add warning when no linear (#23894)

* add warning for gpt2-like models

* more details

* adapt from suggestions
parent 8f915c45
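
For context, the failure mode this commit addresses is easy to reach from the public API: GPT-2-style architectures implement their attention and MLP projections with `Conv1D` rather than `torch.nn.Linear`, so the bitsandbytes replacement pass finds nothing to convert and used to stay silent about it. A minimal sketch of that scenario (assuming a CUDA machine with `bitsandbytes` and `accelerate` installed; the checkpoint name is only illustrative):

```python
from transformers import AutoModelForCausalLM

# GPT-2 uses Conv1D for its projections and lm_head is excluded by default,
# so no module is converted to Linear8bitLt; after this commit a warning about
# "no linear modules were found" is logged instead of failing silently.
model = AutoModelForCausalLM.from_pretrained(
    "gpt2",              # illustrative checkpoint
    load_in_8bit=True,   # requires bitsandbytes and a CUDA device
    device_map="auto",   # requires accelerate
)
```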
src/transformers/modeling_utils.py

@@ -2687,7 +2687,6 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
             model = replace_with_bnb_linear(
                 model, modules_to_not_convert=modules_to_not_convert, quantization_config=quantization_config
             )
-
             # training in 8-bit is only available in 0.37.0+
             model._is_quantized_training_enabled = version.parse(
                 importlib_metadata.version("bitsandbytes")
@@ -2699,8 +2698,8 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
         if load_in_8bit and torch_dtype is None:
             logger.warning(
                 "You are loading your model in 8bit but you did not specify a `torch_dtype` attribute."
-                "All non-linear modules will be loaded in full precision.",
-                " If you want to load the other modules in other precision, please specify a `torch_dtype` attribute.",
+                "All non-linear modules will be loaded in full precision."
+                " If you want to load the other modules in other precision, please specify a `torch_dtype` attribute."
             )

         if isinstance(device_map, str):
...
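
The dropped trailing commas in the second hunk are the actual fix here: with them, the second and third strings were passed as extra positional arguments to `logger.warning`, and Python's `logging` module treats positional arguments as %-format values for the message. Because the message contains no placeholders, formatting fails at emit time and the intended text is mangled. Without the commas, adjacent string literals are concatenated into a single message. A standalone sketch with plain `logging` (the messages are illustrative):

```python
import logging

logging.basicConfig(level=logging.WARNING)
log = logging.getLogger("demo")

# Buggy pattern: the second string becomes a %-format argument, but the message
# has no placeholders, so logging reports a formatting error instead of the text.
log.warning("You did not specify a `torch_dtype`.", " Modules stay in full precision.")

# Fixed pattern: adjacent string literals concatenate into one message, which is
# what the corrected logger.warning call above relies on.
log.warning("You did not specify a `torch_dtype`." " Modules stay in full precision.")
```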
src/transformers/utils/bitsandbytes.py

@@ -3,6 +3,7 @@ from copy import deepcopy
 from packaging import version

+from ..utils import logging
 from .import_utils import importlib_metadata, is_accelerate_available, is_bitsandbytes_available

@@ -15,6 +16,8 @@ if is_accelerate_available():
     from accelerate import init_empty_weights
     from accelerate.utils import find_tied_parameters

+logger = logging.get_logger(__name__)
+

 def set_module_quantized_tensor_to_device(module, tensor_name, device, value=None, fp16_statistics=None):
     """
@@ -106,33 +109,13 @@ def set_module_quantized_tensor_to_device(module, tensor_name, device, value=Non
         module._parameters[tensor_name] = new_value


-def replace_with_bnb_linear(model, modules_to_not_convert=None, current_key_name=None, quantization_config=None):
+def _replace_with_bnb_linear(model, modules_to_not_convert=None, current_key_name=None, quantization_config=None):
     """
-    A helper function to replace all `torch.nn.Linear` modules by `bnb.nn.Linear8bit` modules from the `bitsandbytes`
-    library. This will enable running your models using mixed int8 precision as described by the paper `LLM.int8():
-    8-bit Matrix Multiplication for Transformers at Scale`. Make sure `bitsandbytes` compiled with the correct CUDA
-    version of your hardware is installed before running this function. `pip install -i https://test.pypi.org/simple/
-    bitsandbytes`
+    Private method that wraps the recursion for module replacement.

-    The function will be run recursively and replace all `torch.nn.Linear` modules except for the `lm_head` that should
-    be kept as a `torch.nn.Linear` module. The replacement is done under `init_empty_weights` context manager so no
-    CPU/GPU memory is required to run this function. Int8 mixed-precision matrix decomposition works by separating a
-    matrix multiplication into two streams: (1) and systematic feature outlier stream matrix multiplied in fp16
-    (0.01%), (2) a regular stream of int8 matrix multiplication (99.9%). With this method, int8 inference with no
-    predictive degradation is possible for very large models (>=176B parameters).
-
-    Parameters:
-        model (`torch.nn.Module`):
-            Input model or `torch.nn.Module` as the function is run recursively.
-        modules_to_not_convert (`List[`str`]`, *optional*, defaults to `["lm_head"]`):
-            Names of the modules to not convert in `Linear8bitLt`. In practice we keep the `lm_head` in full precision
-            for numerical stability reasons.
-        current_key_name (`List[`str`]`, *optional*):
-            An array to track the current key of the recursion. This is used to check whether the current key (part of
-            it) is not in the list of modules to not convert (for instances modules that are offloaded to `cpu` or
-            `disk`).
+    Returns the converted model and a boolean that indicates if the conversion has been successfull or not.
     """
-    modules_to_not_convert = ["lm_head"] if modules_to_not_convert is None else modules_to_not_convert
+    has_been_replaced = False
     for name, module in model.named_children():
         if current_key_name is None:
             current_key_name = []
@@ -149,6 +132,7 @@ def replace_with_bnb_linear(model, modules_to_not_convert=None, current_key_name
                             has_fp16_weights=quantization_config.llm_int8_has_fp16_weight,
                             threshold=quantization_config.llm_int8_threshold,
                         )
+                        has_been_replaced = True
                     else:
                         if (
                             quantization_config.llm_int8_skip_modules is not None
@@ -164,16 +148,59 @@ def replace_with_bnb_linear(model, modules_to_not_convert=None, current_key_name
                                 compress_statistics=quantization_config.bnb_4bit_use_double_quant,
                                 quant_type=quantization_config.bnb_4bit_quant_type,
                             )
+                            has_been_replaced = True
                 # Force requires grad to False to avoid unexpected errors
                 model._modules[name].requires_grad_(False)
         # Remove the last key for recursion
         if len(list(module.children())) > 0:
-            replace_with_bnb_linear(
+            _, has_been_replaced = _replace_with_bnb_linear(
                 module,
                 modules_to_not_convert,
                 current_key_name,
                 quantization_config,
             )
+    return model, has_been_replaced
+
+
+def replace_with_bnb_linear(model, modules_to_not_convert=None, current_key_name=None, quantization_config=None):
+    """
+    A helper function to replace all `torch.nn.Linear` modules by `bnb.nn.Linear8bit` modules from the `bitsandbytes`
+    library. This will enable running your models using mixed int8 precision as described by the paper `LLM.int8():
+    8-bit Matrix Multiplication for Transformers at Scale`. Make sure `bitsandbytes` compiled with the correct CUDA
+    version of your hardware is installed before running this function. `pip install -i https://test.pypi.org/simple/
+    bitsandbytes`
+
+    The function will be run recursively and replace all `torch.nn.Linear` modules except for the `lm_head` that should
+    be kept as a `torch.nn.Linear` module. The replacement is done under `init_empty_weights` context manager so no
+    CPU/GPU memory is required to run this function. Int8 mixed-precision matrix decomposition works by separating a
+    matrix multiplication into two streams: (1) and systematic feature outlier stream matrix multiplied in fp16
+    (0.01%), (2) a regular stream of int8 matrix multiplication (99.9%). With this method, int8 inference with no
+    predictive degradation is possible for very large models (>=176B parameters).
+
+    Parameters:
+        model (`torch.nn.Module`):
+            Input model or `torch.nn.Module` as the function is run recursively.
+        modules_to_not_convert (`List[`str`]`, *optional*, defaults to `["lm_head"]`):
+            Names of the modules to not convert in `Linear8bitLt`. In practice we keep the `lm_head` in full precision
+            for numerical stability reasons.
+        current_key_name (`List[`str`]`, *optional*):
+            An array to track the current key of the recursion. This is used to check whether the current key (part of
+            it) is not in the list of modules to not convert (for instances modules that are offloaded to `cpu` or
+            `disk`).
+    """
+    modules_to_not_convert = ["lm_head"] if modules_to_not_convert is None else modules_to_not_convert
+    model, has_been_replaced = _replace_with_bnb_linear(
+        model, modules_to_not_convert, current_key_name, quantization_config
+    )
+
+    if not has_been_replaced:
+        logger.warning(
+            "You are loading your model in 8bit or 4bit but no linear modules were found in your model."
+            " this can happen for some architectures such as gpt2 that uses Conv1D instead of Linear layers."
+            " Please double check your model architecture, or submit an issue on github if you think this is"
+            " a bug."
+        )
+
     return model
...
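
As a sanity check on the new `has_been_replaced` flag, the GPT-2 case can be inspected without `bitsandbytes` at all: once `lm_head` is excluded, there is simply no `nn.Linear` left for the pass to convert, only `Conv1D` modules. A small inspection sketch (assuming `torch` and `transformers` are installed; the checkpoint name is illustrative):

```python
import torch.nn as nn

from transformers import AutoModelForCausalLM
from transformers.pytorch_utils import Conv1D

model = AutoModelForCausalLM.from_pretrained("gpt2")

# nn.Linear is what the int8/4bit pass converts; Conv1D is what GPT-2 actually
# uses for its attention and MLP projections.
linears = [name for name, mod in model.named_modules() if isinstance(mod, nn.Linear)]
conv1ds = [name for name, mod in model.named_modules() if isinstance(mod, Conv1D)]

print(linears)       # ['lm_head'] -- skipped by default via modules_to_not_convert
print(len(conv1ds))  # 48 Conv1D projections, none of which get replaced
```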