Unverified Commit 330d8b99 authored by Andrei Panferov, committed by GitHub

replace_8bit_linear modules_to_not_convert default value fix (#22238)



* Fixed modules_to_not_convert default value

* Fixed modules_to_not_convert docstring

* Update src/transformers/utils/bitsandbytes.py
Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

* Update src/transformers/utils/bitsandbytes.py
Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

* ["lm_head"] if modules_to_not_convert is None

---------
Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
parent c07a02a4
src/transformers/utils/bitsandbytes.py
@@ -84,7 +84,7 @@ def set_module_8bit_tensor_to_device(module, tensor_name, device, value=None):
         module._parameters[tensor_name] = new_value
 
 
-def replace_8bit_linear(model, threshold=6.0, modules_to_not_convert="lm_head", current_key_name=None):
+def replace_8bit_linear(model, threshold=6.0, modules_to_not_convert=None, current_key_name=None):
     """
     A helper function to replace all `torch.nn.Linear` modules by `bnb.nn.Linear8bit` modules from the `bitsandbytes`
     library. This will enable running your models using mixed int8 precision as described by the paper `GPT3.int8():
@@ -105,14 +105,15 @@ def replace_8bit_linear(model, threshold=6.0, modules_to_not_convert="lm_head",
         threshold (`float`, *optional*, defaults to 6.0):
             `int8_threshold` for outlier detection as described in the aforementioned paper. This parameter is set to
             `6.0` as described by the paper.
-        modules_to_not_convert (`str`, *optional*, defaults to `lm_head`):
-            Name of the module to not convert in `Linear8bitLt`. In practice we keep the `lm_head` in full precision
+        modules_to_not_convert (`List[str]`, *optional*, defaults to `["lm_head"]`):
+            Names of the modules to not convert in `Linear8bitLt`. In practice we keep the `lm_head` in full precision
             for numerical stability reasons.
         current_key_name (`List[str]`, *optional*):
             An array to track the current key of the recursion. This is used to check whether the current key (part of
             it) is not in the list of modules to not convert (for instance modules that are offloaded to `cpu` or
             `disk`).
     """
+    modules_to_not_convert = ["lm_head"] if modules_to_not_convert is None else modules_to_not_convert
     for name, module in model.named_children():
         if current_key_name is None:
             current_key_name = []
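Why the fix routes the default through `None` instead of putting `["lm_head"]` directly in the signature: Python evaluates default arguments once at function definition time, so a mutable list default would be shared across every call; the `None` sentinel plus an in-body substitution gives each call a fresh list and matches the docstring's documented default. A minimal sketch of the idiom, using a hypothetical `demo_convert_names` helper rather than the real `replace_8bit_linear`:

```python
# None-sentinel idiom as adopted by this commit. `demo_convert_names` is a
# hypothetical stand-in that only filters names; the real function recurses
# over a model and swaps nn.Linear modules for bnb.nn.Linear8bitLt.

def demo_convert_names(module_names, modules_to_not_convert=None):
    # Same line the commit adds: fill in the documented default here,
    # rather than baking a mutable list into the function signature.
    modules_to_not_convert = ["lm_head"] if modules_to_not_convert is None else modules_to_not_convert
    return [name for name in module_names if name not in modules_to_not_convert]


print(demo_convert_names(["q_proj", "lm_head"]))      # ['q_proj'] -- lm_head kept in full precision
print(demo_convert_names(["q_proj", "lm_head"], []))  # ['q_proj', 'lm_head'] -- explicit empty list converts all
```

With the old `modules_to_not_convert="lm_head"` string default, the signature also contradicted the docstring, which documents a `List[str]`; the `None` sentinel resolves both problems at once.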