Unverified Commit 4d1cce2f authored by Pi Esposito's avatar Pi Esposito Committed by GitHub
Browse files

add accelerate to load models with smaller memory footprint (#361)



* add accelerate to load models with smaller memory footprint

* remove low_cpu_mem_usage as it is reduntant

* move accelerate init weights context to modelling utils

* add test to ensure results are the same when loading with accelerate

* add tests to ensure ram usage gets lower when using accelerate

* move accelerate logic to single snippet under modelling utils and remove it from configuration utils

* format code using to pass quality check

* fix imports with isor

* add accelerate to test extra deps

* only import accelerate if device_map is set to auto

* move accelerate availability check to diffusers import utils

* format code
Co-authored-by: default avatarPatrick von Platen <patrick.v.platen@gmail.com>
parent 09859a3c
......@@ -104,6 +104,7 @@ _deps = [
"torch>=1.4",
"torchvision",
"transformers>=4.21.0",
"accelerate>=0.12.0"
]
# this is a lookup table with items like:
......
......@@ -21,6 +21,7 @@ from typing import Callable, List, Optional, Tuple, Union
import torch
from torch import Tensor, device
from diffusers.utils import is_accelerate_available
from huggingface_hub import hf_hub_download
from huggingface_hub.utils import EntryNotFoundError, RepositoryNotFoundError, RevisionNotFoundError
from requests import HTTPError
......@@ -293,33 +294,13 @@ class ModelMixin(torch.nn.Module):
from_auto_class = kwargs.pop("_from_auto", False)
torch_dtype = kwargs.pop("torch_dtype", None)
subfolder = kwargs.pop("subfolder", None)
device_map = kwargs.pop("device_map", None)
user_agent = {"file_type": "model", "framework": "pytorch", "from_auto_class": from_auto_class}
# Load config if we don't provide a configuration
config_path = pretrained_model_name_or_path
model, unused_kwargs = cls.from_config(
config_path,
cache_dir=cache_dir,
return_unused_kwargs=True,
force_download=force_download,
resume_download=resume_download,
proxies=proxies,
local_files_only=local_files_only,
use_auth_token=use_auth_token,
revision=revision,
subfolder=subfolder,
**kwargs,
)
if torch_dtype is not None and not isinstance(torch_dtype, torch.dtype):
raise ValueError(
f"{torch_dtype} needs to be of type `torch.dtype`, e.g. `torch.float16`, but is {type(torch_dtype)}."
)
elif torch_dtype is not None:
model = model.to(torch_dtype)
model.register_to_config(_name_or_path=pretrained_model_name_or_path)
# This variable will flag if we're loading a sharded checkpoint. In this case the archive file is just the
# Load model
pretrained_model_name_or_path = str(pretrained_model_name_or_path)
......@@ -391,25 +372,81 @@ class ModelMixin(torch.nn.Module):
)
# restore default dtype
state_dict = load_state_dict(model_file)
model, missing_keys, unexpected_keys, mismatched_keys, error_msgs = cls._load_pretrained_model(
model,
state_dict,
model_file,
pretrained_model_name_or_path,
ignore_mismatched_sizes=ignore_mismatched_sizes,
)
# Set model in evaluation mode to deactivate DropOut modules by default
model.eval()
if device_map == "auto":
if is_accelerate_available():
import accelerate
else:
raise ImportError("Please install accelerate via `pip install accelerate`")
with accelerate.init_empty_weights():
model, unused_kwargs = cls.from_config(
config_path,
cache_dir=cache_dir,
return_unused_kwargs=True,
force_download=force_download,
resume_download=resume_download,
proxies=proxies,
local_files_only=local_files_only,
use_auth_token=use_auth_token,
revision=revision,
subfolder=subfolder,
device_map=device_map,
**kwargs,
)
accelerate.load_checkpoint_and_dispatch(model, model_file, device_map)
loading_info = {
"missing_keys": [],
"unexpected_keys": [],
"mismatched_keys": [],
"error_msgs": [],
}
else:
model, unused_kwargs = cls.from_config(
config_path,
cache_dir=cache_dir,
return_unused_kwargs=True,
force_download=force_download,
resume_download=resume_download,
proxies=proxies,
local_files_only=local_files_only,
use_auth_token=use_auth_token,
revision=revision,
subfolder=subfolder,
device_map=device_map,
**kwargs,
)
state_dict = load_state_dict(model_file)
model, missing_keys, unexpected_keys, mismatched_keys, error_msgs = cls._load_pretrained_model(
model,
state_dict,
model_file,
pretrained_model_name_or_path,
ignore_mismatched_sizes=ignore_mismatched_sizes,
)
if output_loading_info:
loading_info = {
"missing_keys": missing_keys,
"unexpected_keys": unexpected_keys,
"mismatched_keys": mismatched_keys,
"error_msgs": error_msgs,
}
if torch_dtype is not None and not isinstance(torch_dtype, torch.dtype):
raise ValueError(
f"{torch_dtype} needs to be of type `torch.dtype`, e.g. `torch.float16`, but is {type(torch_dtype)}."
)
elif torch_dtype is not None:
model = model.to(torch_dtype)
model.register_to_config(_name_or_path=pretrained_model_name_or_path)
# Set model in evaluation mode to deactivate DropOut modules by default
model.eval()
if output_loading_info:
return model, loading_info
return model
......
......@@ -23,6 +23,7 @@ from .import_utils import (
USE_TF,
USE_TORCH,
DummyObject,
is_accelerate_available,
is_flax_available,
is_inflect_available,
is_modelcards_available,
......
......@@ -159,6 +159,13 @@ try:
except importlib_metadata.PackageNotFoundError:
_scipy_available = False
_accelerate_available = importlib.util.find_spec("accelerate") is not None
try:
_accelerate_version = importlib_metadata.version("accelerate")
logger.debug(f"Successfully imported accelerate version {_accelerate_version}")
except importlib_metadata.PackageNotFoundError:
_accelerate_available = False
def is_torch_available():
return _torch_available
......@@ -196,6 +203,10 @@ def is_scipy_available():
return _scipy_available
def is_accelerate_available():
return _accelerate_available
# docstyle-ignore
FLAX_IMPORT_ERROR = """
{0} requires the FLAX library but it was not found in your environment. Checkout the instructions on the
......
......@@ -13,7 +13,9 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import gc
import math
import tracemalloc
import unittest
import torch
......@@ -133,6 +135,74 @@ class UNetLDMModelTests(ModelTesterMixin, unittest.TestCase):
assert image is not None, "Make sure output is not None"
@unittest.skipIf(torch_device == "cpu", "This test is supposed to run on GPU")
def test_from_pretrained_accelerate(self):
model, _ = UNet2DModel.from_pretrained(
"fusing/unet-ldm-dummy-update", output_loading_info=True, device_map="auto"
)
model.to(torch_device)
image = model(**self.dummy_input).sample
assert image is not None, "Make sure output is not None"
@unittest.skipIf(torch_device == "cpu", "This test is supposed to run on GPU")
def test_from_pretrained_accelerate_wont_change_results(self):
model_accelerate, _ = UNet2DModel.from_pretrained(
"fusing/unet-ldm-dummy-update", output_loading_info=True, device_map="auto"
)
model_accelerate.to(torch_device)
model_accelerate.eval()
noise = torch.randn(
1,
model_accelerate.config.in_channels,
model_accelerate.config.sample_size,
model_accelerate.config.sample_size,
generator=torch.manual_seed(0),
)
noise = noise.to(torch_device)
time_step = torch.tensor([10] * noise.shape[0]).to(torch_device)
arr_accelerate = model_accelerate(noise, time_step)["sample"]
# two models don't need to stay in the device at the same time
del model_accelerate
torch.cuda.empty_cache()
gc.collect()
model_normal_load, _ = UNet2DModel.from_pretrained("fusing/unet-ldm-dummy-update", output_loading_info=True)
model_normal_load.to(torch_device)
model_normal_load.eval()
arr_normal_load = model_normal_load(noise, time_step)["sample"]
assert torch.allclose(arr_accelerate, arr_normal_load, rtol=1e-3)
@unittest.skipIf(torch_device == "cpu", "This test is supposed to run on GPU")
def test_memory_footprint_gets_reduced(self):
torch.cuda.empty_cache()
gc.collect()
tracemalloc.start()
model_accelerate, _ = UNet2DModel.from_pretrained(
"fusing/unet-ldm-dummy-update", output_loading_info=True, device_map="auto"
)
model_accelerate.to(torch_device)
model_accelerate.eval()
_, peak_accelerate = tracemalloc.get_traced_memory()
del model_accelerate
torch.cuda.empty_cache()
gc.collect()
model_normal_load, _ = UNet2DModel.from_pretrained("fusing/unet-ldm-dummy-update", output_loading_info=True)
model_normal_load.to(torch_device)
model_normal_load.eval()
_, peak_normal = tracemalloc.get_traced_memory()
tracemalloc.stop()
assert peak_accelerate < peak_normal
def test_output_pretrained(self):
model = UNet2DModel.from_pretrained("fusing/unet-ldm-dummy-update")
model.eval()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment