Unverified Commit 72282876 authored by IrisRainbowNeko, committed by GitHub

Add low_cpu_mem_usage option to from_single_file to align with from_pretrained (#12114)



* align meta device of from_single_file with from_pretrained

* update docstr

* Apply style fixes

---------
Co-authored-by: IrisRainbowNeko <rainbow-neko@outlook.com>
Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
parent 3552279a
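Not part of the commit, just for context: a minimal sketch of how the new option is used. The checkpoint path is a placeholder, and `StableCascadeUNet` is one of the classes listed in `SINGLE_FILE_LOADABLE_CLASSES` in the diff below.

```python
from diffusers import StableCascadeUNet

# With low_cpu_mem_usage=True (the default when torch >= 1.9.0 and
# accelerate are available), the model is built on the "meta" device
# and the checkpoint weights are materialized into it directly,
# instead of first allocating randomly initialized weights on CPU.
unet = StableCascadeUNet.from_single_file(
    "path/to/stable_cascade.safetensors",  # hypothetical path
    low_cpu_mem_usage=True,
)
```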
@@ -23,7 +23,7 @@ from typing_extensions import Self
from .. import __version__
from ..quantizers import DiffusersAutoQuantizer
-from ..utils import deprecate, is_accelerate_available, logging
+from ..utils import deprecate, is_accelerate_available, is_torch_version, logging
from ..utils.torch_utils import empty_device_cache
from .single_file_utils import (
SingleFileComponentError,
@@ -64,6 +64,10 @@ if is_accelerate_available():
from ..models.modeling_utils import load_model_dict_into_meta
+if is_torch_version(">=", "1.9.0") and is_accelerate_available():
+    _LOW_CPU_MEM_USAGE_DEFAULT = True
+else:
+    _LOW_CPU_MEM_USAGE_DEFAULT = False
SINGLE_FILE_LOADABLE_CLASSES = {
"StableCascadeUNet": {
@@ -236,6 +240,11 @@ class FromOriginalModelMixin:
revision (`str`, *optional*, defaults to `"main"`):
The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier
allowed by Git.
+low_cpu_mem_usage (`bool`, *optional*, defaults to `True` if torch version >= 1.9.0 and
+is_accelerate_available() else `False`): Speed up model loading by only loading the pretrained weights and
+not initializing the weights. This also tries to not use more than 1x the model size in CPU memory
+(including peak memory) while loading the model. Only supported for PyTorch >= 1.9.0. If you are using
+an older version of PyTorch, setting this argument to `True` will raise an error.
disable_mmap (`bool`, *optional*, defaults to `False`):
Whether to disable mmap when loading a Safetensors model. This option can perform better when the model
is on a network mount or hard drive, which may not handle the seeky-ness of mmap very well.
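As a usage note (not part of the diff): setting the flag to `False` opts out of meta-device loading, which a caller might want when debugging weight initialization. A hedged sketch, with a hypothetical checkpoint path:

```python
# low_cpu_mem_usage=False initializes the full model on CPU first and
# then loads the checkpoint weights over it; this is slower and peak
# CPU memory can briefly approach ~2x the model size, but no meta
# tensors are involved.
unet = StableCascadeUNet.from_single_file(
    "path/to/stable_cascade.safetensors",  # hypothetical path
    low_cpu_mem_usage=False,
)
```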
@@ -285,6 +294,7 @@ class FromOriginalModelMixin:
config_revision = kwargs.pop("config_revision", None)
torch_dtype = kwargs.pop("torch_dtype", None)
quantization_config = kwargs.pop("quantization_config", None)
+low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT)
device = kwargs.pop("device", None)
disable_mmap = kwargs.pop("disable_mmap", False)
@@ -389,7 +399,7 @@ class FromOriginalModelMixin:
model_kwargs = {k: kwargs.get(k) for k in kwargs if k in expected_kwargs or k in optional_kwargs}
diffusers_model_config.update(model_kwargs)
-ctx = init_empty_weights if is_accelerate_available() else nullcontext
+ctx = init_empty_weights if low_cpu_mem_usage else nullcontext
with ctx():
model = cls.from_config(diffusers_model_config)
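`init_empty_weights` is accelerate's context manager for allocating parameters on the meta device; a standalone sketch of its effect:

```python
import torch.nn as nn
from accelerate import init_empty_weights

# Inside the context, parameters are created on the "meta" device:
# they carry shape and dtype but no storage, so constructing even a
# very large module consumes almost no host memory.
with init_empty_weights():
    big = nn.Linear(8192, 8192)

print(big.weight.device)  # -> meta
```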
@@ -427,7 +437,7 @@ class FromOriginalModelMixin:
)
device_map = None
-if is_accelerate_available():
+if low_cpu_mem_usage:
param_device = torch.device(device) if device else torch.device("cpu")
empty_state_dict = model.state_dict()
unexpected_keys = [
......
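A model built on the meta device cannot run until real tensors are assigned; diffusers does this with `load_model_dict_into_meta` (imported in the first hunk above). As a rough illustration only, not the library's code path, recent PyTorch exposes a similar primitive:

```python
import torch
import torch.nn as nn

# Build on meta, then materialize from a state dict. assign=True
# (PyTorch >= 2.1) swaps the storage-less meta parameters for the
# loaded tensors rather than copying into them.
with torch.device("meta"):
    model = nn.Linear(16, 16)

state_dict = {"weight": torch.randn(16, 16), "bias": torch.randn(16)}
model.load_state_dict(state_dict, assign=True)
```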