Unverified commit 42bb4594, authored by Patrick von Platen, committed by GitHub

[Low cpu memory] Correct naming and improve default usage (#1122)



* correct naming

* finish

* Apply suggestions from code review

* Apply suggestions from code review
Co-authored-by: Suraj Patil <surajp815@gmail.com>
parent 988c8222
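For context, a minimal sketch of the user-visible change, assuming the dummy checkpoint used in this repo's tests: the old `fast_load` kwarg becomes `low_cpu_mem_usage`, defaulting to `True` on torch >= 1.9.0.

```python
from diffusers import UNet2DModel

# Before this commit: fast_load=True (silently ignored on torch < 1.9.0).
# After this commit: low_cpu_mem_usage, on by default on torch >= 1.9.0 and
# raising a NotImplementedError if forced to True on an older torch.
model = UNet2DModel.from_pretrained(
    "fusing/unet-ldm-dummy-update", low_cpu_mem_usage=True
)
```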
@@ -35,6 +35,12 @@ from .utils import CONFIG_NAME, DIFFUSERS_CACHE, HUGGINGFACE_CO_RESOLVE_ENDPOINT
 
 logger = logging.get_logger(__name__)
 
+if is_torch_version(">=", "1.9.0"):
+    _LOW_CPU_MEM_USAGE_DEFAULT = True
+else:
+    _LOW_CPU_MEM_USAGE_DEFAULT = False
+
+
 def get_parameter_device(parameter: torch.nn.Module):
     try:
         return next(parameter.parameters()).device
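The gate above uses the `is_torch_version` helper; a self-contained equivalent, assuming only `packaging` and `torch`, would be:

```python
import torch
from packaging import version

# True on torch >= 1.9.0, where accelerate's empty-weight init is available.
_LOW_CPU_MEM_USAGE_DEFAULT = version.parse(
    version.parse(torch.__version__).base_version
) >= version.parse("1.9.0")
```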
@@ -278,11 +284,11 @@ class ModelMixin(torch.nn.Module):
                 To have Accelerate compute the most optimized `device_map` automatically, set `device_map="auto"`. For
                 more information about each option see [designing a device
                 map](https://hf.co/docs/accelerate/main/en/usage_guides/big_modeling#designing-a-device-map).
-            fast_load (`bool`, *optional*, defaults to `True`):
+            low_cpu_mem_usage (`bool`, *optional*, defaults to `True` if torch version >= 1.9.0 else `False`):
                 Speed up model loading by not initializing the weights and only loading the pre-trained weights. This
                 also tries to not use more than 1x model size in CPU memory (including peak memory) while loading the
-                model. This is only supported when torch version >= 1.9.0. If you are using an older version of torch,
-                this argument will be ignored and the model will be loaded normally.
+                model. This is only supported when torch version >= 1.9.0. If you are using an older version of torch,
+                setting this argument to `True` will raise an error.
 
         <Tip>
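As the updated docstring says, the fast path can still be disabled explicitly; a short sketch using the dummy checkpoint from the tests further down:

```python
from diffusers import UNet2DModel

# Opt out of the accelerate fast path: weights are randomly initialized
# first and then overwritten by the checkpoint (slower, higher peak memory).
model = UNet2DModel.from_pretrained(
    "fusing/unet-ldm-dummy-update", low_cpu_mem_usage=False
)
```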
@@ -311,16 +317,26 @@ class ModelMixin(torch.nn.Module):
         torch_dtype = kwargs.pop("torch_dtype", None)
         subfolder = kwargs.pop("subfolder", None)
         device_map = kwargs.pop("device_map", None)
-        fast_load = kwargs.pop("fast_load", True)
+        low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT)
 
         # Check if we can handle device_map and dispatching the weights
         if device_map is not None and not is_torch_version(">=", "1.9.0"):
-            raise NotImplementedError("Loading and dispatching requires torch >= 1.9.0")
+            raise NotImplementedError(
+                "Loading and dispatching requires torch >= 1.9.0. Please either update your PyTorch version or set"
+                " `device_map=None`."
+            )
 
-        # Fast init is only possible if torch version is >= 1.9.0
-        _INIT_EMPTY_WEIGHTS = fast_load or device_map is not None
-        if _INIT_EMPTY_WEIGHTS and not is_torch_version(">=", "1.9.0"):
-            logger.warn("Loading with `fast_load` requires torch >= 1.9.0. Falling back to normal loading.")
+        if low_cpu_mem_usage is True and not is_torch_version(">=", "1.9.0"):
+            raise NotImplementedError(
+                "Low memory initialization requires torch >= 1.9.0. Please either update your PyTorch version or set"
+                " `low_cpu_mem_usage=False`."
+            )
+
+        if low_cpu_mem_usage is False and device_map is not None:
+            raise ValueError(
+                f"You cannot set `low_cpu_mem_usage` to `False` while using device_map={device_map} for loading and"
+                " dispatching. Please make sure to set `low_cpu_mem_usage=True`."
+            )
 
         user_agent = {
             "diffusers": __version__,
@@ -403,7 +419,7 @@ class ModelMixin(torch.nn.Module):
         # restore default dtype
 
-        if _INIT_EMPTY_WEIGHTS:
+        if low_cpu_mem_usage:
             # Instantiate model with empty weights
             with accelerate.init_empty_weights():
                 model, unused_kwargs = cls.from_config(
...
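The `low_cpu_mem_usage` branch builds on accelerate's meta-device context manager; a minimal standalone sketch of what `init_empty_weights` does:

```python
import torch
from accelerate import init_empty_weights

# Inside the context, parameters are created on the "meta" device, so no
# real storage is allocated until checkpoint weights are loaded in later.
with init_empty_weights():
    layer = torch.nn.Linear(1024, 1024)

assert layer.weight.device == torch.device("meta")
```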
@@ -25,6 +25,7 @@ import torch
 
 import diffusers
 import PIL
+from accelerate.utils.versions import is_torch_version
 from huggingface_hub import snapshot_download
 from packaging import version
 from PIL import Image
@@ -33,6 +34,7 @@ from tqdm.auto import tqdm
 from .configuration_utils import ConfigMixin
 from .dynamic_modules_utils import get_class_from_dynamic_module
 from .hub_utils import http_user_agent
+from .modeling_utils import _LOW_CPU_MEM_USAGE_DEFAULT
 from .schedulers.scheduling_utils import SCHEDULER_CONFIG_NAME
 from .utils import (
     CONFIG_NAME,
@@ -328,6 +330,19 @@ class DiffusionPipeline(ConfigMixin):
                 Mirror source to accelerate downloads in China. If you are from China and have an accessibility
                 problem, you can set this option to resolve it. Note that we do not guarantee the timeliness or safety.
                 Please refer to the mirror site for more information. specify the folder name here.
+            device_map (`str` or `Dict[str, Union[int, str, torch.device]]`, *optional*):
+                A map that specifies where each submodule should go. It doesn't need to be refined to each
+                parameter/buffer name, once a given module name is inside, every submodule of it will be sent to the
+                same device.
+
+                To have Accelerate compute the most optimized `device_map` automatically, set `device_map="auto"`. For
+                more information about each option see [designing a device
+                map](https://hf.co/docs/accelerate/main/en/usage_guides/big_modeling#designing-a-device-map).
+            low_cpu_mem_usage (`bool`, *optional*, defaults to `True` if torch version >= 1.9.0 else `False`):
+                Speed up model loading by not initializing the weights and only loading the pre-trained weights. This
+                also tries to not use more than 1x model size in CPU memory (including peak memory) while loading the
+                model. This is only supported when torch version >= 1.9.0. If you are using an older version of torch,
+                setting this argument to `True` will raise an error.
             kwargs (remaining dictionary of keyword arguments, *optional*):
                 Can be used to overwrite load - and saveable variables - *i.e.* the pipeline components - of the
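Taken together, the two new docstring entries describe calls like the following sketch (assumes torch >= 1.9.0, accelerate installed, and access to the checkpoint):

```python
import torch
from diffusers import DiffusionPipeline

pipe = DiffusionPipeline.from_pretrained(
    "CompVis/stable-diffusion-v1-4",
    torch_dtype=torch.float16,
    device_map="auto",       # let Accelerate place the submodules
    low_cpu_mem_usage=True,  # the default on torch >= 1.9.0
)
```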
@@ -380,7 +395,25 @@ class DiffusionPipeline(ConfigMixin):
         provider = kwargs.pop("provider", None)
         sess_options = kwargs.pop("sess_options", None)
         device_map = kwargs.pop("device_map", None)
-        fast_load = kwargs.pop("fast_load", True)
+        low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT)
+
+        if device_map is not None and not is_torch_version(">=", "1.9.0"):
+            raise NotImplementedError(
+                "Loading and dispatching requires torch >= 1.9.0. Please either update your PyTorch version or set"
+                " `device_map=None`."
+            )
+
+        if low_cpu_mem_usage is True and not is_torch_version(">=", "1.9.0"):
+            raise NotImplementedError(
+                "Low memory initialization requires torch >= 1.9.0. Please either update your PyTorch version or set"
+                " `low_cpu_mem_usage=False`."
+            )
+
+        if low_cpu_mem_usage is False and device_map is not None:
+            raise ValueError(
+                f"You cannot set `low_cpu_mem_usage` to False while using device_map={device_map} for loading and"
+                " dispatching. Please make sure to set `low_cpu_mem_usage=True`."
+            )
 
         # 1. Download the checkpoints and configs
         # use snapshot download here to get it working from from_pretrained
@@ -573,17 +606,12 @@ class DiffusionPipeline(ConfigMixin):
                 and version.parse(version.parse(transformers.__version__).base_version) >= version.parse("4.20.0")
             )
 
-            if is_diffusers_model:
-                loading_kwargs["fast_load"] = fast_load
-
             # When loading a transformers model, if the device_map is None, the weights will be initialized as opposed to diffusers.
-            # To make default loading faster we set the `low_cpu_mem_usage=fast_load` flag which is `True` by default.
+            # To make default loading faster we set the `low_cpu_mem_usage=low_cpu_mem_usage` flag which is `True` by default.
             # This makes sure that the weights won't be initialized which significantly speeds up loading.
-            if is_transformers_model and device_map is None:
-                loading_kwargs["low_cpu_mem_usage"] = fast_load
-
             if is_diffusers_model or is_transformers_model:
                 loading_kwargs["device_map"] = device_map
+                loading_kwargs["low_cpu_mem_usage"] = low_cpu_mem_usage
 
             # check if the module is in a subdirectory
             if os.path.isdir(os.path.join(cached_folder, name)):
...
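The single pass-through works because transformers (>= 4.20.0, per the version check above) exposes a `low_cpu_mem_usage` flag with the same semantics, so both libraries skip weight initialization. A sketch for one of the pipeline's submodels (model id illustrative):

```python
from transformers import CLIPTextModel

# The pipeline forwards the same kwarg to transformers submodels such as
# the text encoder.
text_encoder = CLIPTextModel.from_pretrained(
    "openai/clip-vit-large-patch14", low_cpu_mem_usage=True
)
```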
@@ -133,7 +133,7 @@ class UNetLDMModelTests(ModelTesterMixin, unittest.TestCase):
     @unittest.skipIf(torch_device != "cuda", "This test is supposed to run on GPU")
     def test_from_pretrained_accelerate_wont_change_results(self):
-        # by defautl model loading will use accelerate as `fast_load=True`
+        # by default model loading will use accelerate as `low_cpu_mem_usage=True`
         model_accelerate, _ = UNet2DModel.from_pretrained("fusing/unet-ldm-dummy-update", output_loading_info=True)
         model_accelerate.to(torch_device)
         model_accelerate.eval()
@@ -156,7 +156,7 @@ class UNetLDMModelTests(ModelTesterMixin, unittest.TestCase):
         gc.collect()
 
         model_normal_load, _ = UNet2DModel.from_pretrained(
-            "fusing/unet-ldm-dummy-update", output_loading_info=True, fast_init=False
+            "fusing/unet-ldm-dummy-update", output_loading_info=True, low_cpu_mem_usage=False
         )
         model_normal_load.to(torch_device)
         model_normal_load.eval()
@@ -170,7 +170,7 @@ class UNetLDMModelTests(ModelTesterMixin, unittest.TestCase):
         gc.collect()
 
         tracemalloc.start()
-        # by defautl model loading will use accelerate as `fast_load=True`
+        # by default model loading will use accelerate as `low_cpu_mem_usage=True`
         model_accelerate, _ = UNet2DModel.from_pretrained("fusing/unet-ldm-dummy-update", output_loading_info=True)
         model_accelerate.to(torch_device)
         model_accelerate.eval()
@@ -181,7 +181,7 @@ class UNetLDMModelTests(ModelTesterMixin, unittest.TestCase):
         gc.collect()
 
         model_normal_load, _ = UNet2DModel.from_pretrained(
-            "fusing/unet-ldm-dummy-update", output_loading_info=True, fast_init=False
+            "fusing/unet-ldm-dummy-update", output_loading_info=True, low_cpu_mem_usage=False
         )
         model_normal_load.to(torch_device)
         model_normal_load.eval()
...
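The memory test above relies on `tracemalloc` to compare peak CPU usage between the two loading paths; a condensed sketch of that measurement pattern:

```python
import gc
import tracemalloc

from diffusers import UNet2DModel

gc.collect()
tracemalloc.start()
UNet2DModel.from_pretrained("fusing/unet-ldm-dummy-update")
_, peak = tracemalloc.get_traced_memory()  # returns (current, peak) in bytes
tracemalloc.stop()
print(f"peak CPU memory while loading: {peak / 2**20:.1f} MiB")
```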
@@ -823,23 +823,23 @@ class StableDiffusionPipelineIntegrationTests(unittest.TestCase):
         assert test_callback_fn.has_been_called
         assert number_of_steps == 51
 
-    def test_stable_diffusion_fast_load(self):
+    def test_stable_diffusion_low_cpu_mem_usage(self):
         pipeline_id = "CompVis/stable-diffusion-v1-4"
 
         start_time = time.time()
-        pipeline_fast_load = StableDiffusionPipeline.from_pretrained(
+        pipeline_low_cpu_mem_usage = StableDiffusionPipeline.from_pretrained(
             pipeline_id, revision="fp16", torch_dtype=torch.float16
         )
-        pipeline_fast_load.to(torch_device)
-        fast_load_time = time.time() - start_time
+        pipeline_low_cpu_mem_usage.to(torch_device)
+        low_cpu_mem_usage_time = time.time() - start_time
 
         start_time = time.time()
         _ = StableDiffusionPipeline.from_pretrained(
-            pipeline_id, revision="fp16", torch_dtype=torch.float16, use_auth_token=True, fast_load=False
+            pipeline_id, revision="fp16", torch_dtype=torch.float16, use_auth_token=True, low_cpu_mem_usage=False
         )
         normal_load_time = time.time() - start_time
 
-        assert 2 * fast_load_time < normal_load_time
+        assert 2 * low_cpu_mem_usage_time < normal_load_time
 
     @unittest.skipIf(torch_device == "cpu", "This test is supposed to run on GPU")
     def test_stable_diffusion_pipeline_with_unet_on_gpu_only(self):
...