Unverified Commit c5be7cae authored by Stas Bekman's avatar Stas Bekman Committed by GitHub
Browse files

postpone bnb load until it's needed (#18859)

parent 9e346f74
......@@ -84,8 +84,6 @@ if is_accelerate_available():
else:
get_balanced_memory = None
if is_bitsandbytes_available():
from .utils.bitsandbytes import get_key_to_not_convert, replace_8bit_linear, set_module_8bit_tensor_to_device
logger = logging.get_logger(__name__)
......@@ -527,6 +525,9 @@ def _load_state_dict_into_meta_model(
# - Is there a situation where some keys aren't in `loaded_state_dict_keys` and in which case
# they won't get loaded.
if load_in_8bit:
from .utils.bitsandbytes import set_module_8bit_tensor_to_device
error_msgs = []
old_keys = []
......@@ -2142,6 +2143,8 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
model = cls(config, *model_args, **model_kwargs)
if load_in_8bit:
from .utils.bitsandbytes import get_key_to_not_convert, replace_8bit_linear
logger.info("Detected 8-bit loading: activating 8-bit loading for this model")
# We never convert lm_head or any last modules for numerical stability reasons
......@@ -2279,6 +2282,9 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
dtype=None,
load_in_8bit=False,
):
if load_in_8bit:
from .utils.bitsandbytes import set_module_8bit_tensor_to_device
if device_map is not None and "disk" in device_map.values():
if offload_folder is None:
raise ValueError(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment