Unverified Commit 7743cacc authored by Younes Belkada, committed by GitHub

[bnb] Small improvements on utils (#18646)



* Small replacement

- replace `modules_to_not_convert` by `module_to_not_convert`

* refactor a bit

- changed variables name
- now output a list
- change error message

* make style

* add list

* make style

* change args name
Co-authored-by: stas00 <stas00@users.noreply.github.com>

* fix comment

* fix typo
Co-authored-by: stas00 <stas00@users.noreply.github.com>

* Update src/transformers/modeling_utils.py
Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
Co-authored-by: stas00 <stas00@users.noreply.github.com>
Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
parent 8edf1963
src/transformers/modeling_utils.py

@@ -1751,7 +1751,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
https://test.pypi.org/simple/ bitsandbytes-cudaXXX` where XXX is your CUDA version (e.g. 11.6 = 116).
Make also sure that you have enough GPU RAM to store half of the model size since the 8bit modules are
not compiled and adapted for CPUs.
-int8_threshold (`float`, *optional*, defaults to 6):
+load_in_8bit_threshold (`float`, *optional*, defaults to 6):
Works together with `load_in_8bit`. This corresponds to the outlier threshold for outlier detection as
described in `GPT3.int8() : 8-bit Matrix Multiplication for Transformers at Scale` paper. Any hidden
states value that is above this threshold will be considered an outlier and the operation on those
@@ -1761,6 +1761,9 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
quantization works well for values of magnitude ~5, but beyond that, there is a significant performance
penalty. A good default threshold is 6, but a lower threshold might be needed for more unstable models
(small models, fine-tuning).
+load_in_8bit_skip_modules (`List[str]`, *optional*):
+An explicit list of the modules that we do not want to convert in 8-bit. This is useful for models such
+as Jukebox that has several heads in different places and not necessarily at the last position.
subfolder (`str`, *optional*, defaults to `""`):
In case the relevant files are located inside a subfolder of the model repo on huggingface.co, you can
specify the folder name here.
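For context, a minimal usage sketch of the renamed and newly added arguments; the checkpoint id and the skip list below are illustrative placeholders, not part of this diff, and a CUDA GPU with `bitsandbytes` installed is assumed:

from transformers import AutoModelForCausalLM

# "facebook/opt-350m" is only an example checkpoint; any supported model id works.
model = AutoModelForCausalLM.from_pretrained(
    "facebook/opt-350m",
    device_map="auto",
    load_in_8bit=True,                      # convert eligible nn.Linear layers to 8-bit
    load_in_8bit_threshold=6.0,             # outlier threshold described in the docstring above
    load_in_8bit_skip_modules=["lm_head"],  # explicit list of modules to keep un-quantized
)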
@@ -1852,7 +1855,8 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
offload_folder = kwargs.pop("offload_folder", None)
offload_state_dict = kwargs.pop("offload_state_dict", False)
load_in_8bit = kwargs.pop("load_in_8bit", False)
-int8_threshold = kwargs.pop("int8_threshold", 6.0)
+load_in_8bit_threshold = kwargs.pop("load_in_8bit_threshold", 6.0)
+load_in_8bit_skip_modules = kwargs.pop("load_in_8bit_skip_modules", None)
subfolder = kwargs.pop("subfolder", "")
commit_hash = kwargs.pop("_commit_hash", None)
@@ -2156,13 +2160,18 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
model = cls(config, *model_args, **model_kwargs)

if load_in_8bit:
-    from .utils.bitsandbytes import get_key_to_not_convert, replace_8bit_linear
+    from .utils.bitsandbytes import get_keys_to_not_convert, replace_8bit_linear

    logger.info("Detected 8-bit loading: activating 8-bit loading for this model")
-    # We never convert lm_head or any last modules for numerical stability reasons
-    modules_to_not_convert = get_key_to_not_convert(model)
-    model = replace_8bit_linear(model, threshold=int8_threshold, modules_to_not_convert=modules_to_not_convert)
+    # We keep some modules such as the lm_head in their original dtype for numerical stability reasons
+    if load_in_8bit_skip_modules is None:
+        modules_to_not_convert = get_keys_to_not_convert(model)
+    else:
+        modules_to_not_convert = load_in_8bit_skip_modules
+    model = replace_8bit_linear(
+        model, threshold=load_in_8bit_threshold, modules_to_not_convert=modules_to_not_convert
+    )

if isinstance(device_map, str):
    if model._no_split_modules is None:
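To see which modules the heuristic would keep in full precision before overriding it with load_in_8bit_skip_modules, the helper can be called directly; a hedged sketch, assuming `model` is an already-instantiated PreTrainedModel:

from transformers.utils.bitsandbytes import get_keys_to_not_convert

# For most causal LM architectures this prints something like ["lm_head"].
print(get_keys_to_not_convert(model))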
@@ -2193,12 +2202,18 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
)

if load_in_8bit:
-    # The LM head can stay on disk / CPU
+    # The LM head / tied weights or any last module can stay on disk / CPU
    device_map_without_lm_head = {
-        key: device_map[key] for key in device_map.keys() if key != modules_to_not_convert
+        key: device_map[key] for key in device_map.keys() if key not in modules_to_not_convert
    }
    if "cpu" in device_map_without_lm_head.values() or "disk" in device_map_without_lm_head.values():
-        raise ValueError("8-bit operations on `bitsandbytes` are not supported under CPU!")
+        raise ValueError(
+            """
+            Some modules are dispatched on the CPU or the disk. Make sure you have enough GPU RAM to fit
+            the quantized model. If you have set a value for `max_memory` you should increase that. To have
+            an idea of the modules that are set on the CPU or RAM you can print model.hf_device_map.
+            """
+        )
    del device_map_without_lm_head

if from_tf:
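The new error message points users at `max_memory` and `model.hf_device_map`; a hedged sketch of how one might react to it, with a placeholder checkpoint id and placeholder memory budgets:

from transformers import AutoModelForCausalLM

# Give the GPU a larger share so no quantized module spills to CPU or disk
# (values are placeholders; tune them to the actual hardware).
model = AutoModelForCausalLM.from_pretrained(
    "facebook/opt-350m",
    device_map="auto",
    load_in_8bit=True,
    max_memory={0: "14GiB", "cpu": "30GiB"},
)
print(model.hf_device_map)  # shows which device each top-level module was placed on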
src/transformers/utils/bitsandbytes.py
@@ -114,7 +114,7 @@ def replace_8bit_linear(model, threshold=6.0, modules_to_not_convert="lm_head"):
if len(list(module.children())) > 0:
    replace_8bit_linear(module, threshold, modules_to_not_convert)

-if isinstance(module, nn.Linear) and name != modules_to_not_convert:
+if isinstance(module, nn.Linear) and name not in modules_to_not_convert:
    with init_empty_weights():
        model._modules[name] = bnb.nn.Linear8bitLt(
            module.in_features,
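Since modules_to_not_convert is now a list of names rather than a single string, the old equality check would no longer match; a toy illustration of the difference (values are invented):

modules_to_not_convert = ["lm_head", "proj_out"]  # example values
name = "lm_head"

print(name != modules_to_not_convert)      # True  -> the old check would wrongly convert lm_head
print(name not in modules_to_not_convert)  # False -> the new check correctly skips it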
@@ -126,10 +126,12 @@ def replace_8bit_linear(model, threshold=6.0, modules_to_not_convert="lm_head"):
return model


-def get_key_to_not_convert(model):
+def get_keys_to_not_convert(model):
    r"""
    An utility function to get the key of the module to keep in full precision if any For example for CausalLM modules
-    we may want to keep the lm_head in full precision for numerical stability reasons.
+    we may want to keep the lm_head in full precision for numerical stability reasons. For other architectures, we want
+    to keep the tied weights of the model. The function will return a list of the keys of the modules to not convert in
+    int8.

    Parameters:
        model (`torch.nn.Module`):
@@ -139,7 +141,9 @@ def get_key_to_not_convert(model):
# check if it contains tied weights
tied_model = deepcopy(model)  # this has 0 cost since it is done inside `init_empty_weights` context manager`
tied_model.tie_weights()
-has_tied_params = len(find_tied_parameters(tied_model)) > 0
+
+tied_keys = list(find_tied_parameters(tied_model).values())
+has_tied_params = len(tied_keys) > 0

# Check if it is a base model
is_base_model = not hasattr(model, model.base_model_prefix)
@@ -150,5 +154,10 @@ def get_key_to_not_convert(model):
# otherwise they have an attached head
list_modules = list(model.named_parameters())
-last_name = list_modules[-1][0]
-return last_name.split(".")[0]
+list_last_module = [list_modules[-1][0]]
+
+# add last module together with tied weights
+intersection = set(list_last_module) - set(tied_keys)
+list_untouched = tied_keys + list(intersection)
+
+return [module_name.split(".")[0] for module_name in list_untouched]
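To make the new return value concrete, a toy walk-through of the set arithmetic above (the module names are invented for illustration):

tied_keys = ["lm_head.weight"]         # keys of tied parameters found on the copied model
list_last_module = ["lm_head.weight"]  # key of the very last named parameter

intersection = set(list_last_module) - set(tied_keys)  # empty here: the last module is already tied
list_untouched = tied_keys + list(intersection)        # ["lm_head.weight"]

print([name.split(".")[0] for name in list_untouched])  # ["lm_head"] -> kept in full precision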