Unverified Commit 9e7ae568 authored by Sayak Paul, committed by GitHub

[feat] cache allocator warmup for `from_single_file` (#12305)

* add

* add a test
parent f7b79452
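
In short, `from_single_file` now accepts a `device_map` kwarg and, once a device map resolves, warms up the accelerator's caching allocator before weights are copied onto the device. A minimal usage sketch, assuming a CUDA machine and a single-file FLUX checkpoint (the URL below is illustrative, not taken from this commit):

    import torch
    from diffusers import FluxTransformer2DModel

    # Illustrative checkpoint location; any single-file FLUX checkpoint works.
    ckpt_url = "https://huggingface.co/black-forest-labs/FLUX.1-dev/blob/main/flux1-dev.safetensors"

    # Passing `device_map` now triggers the caching-allocator warmup, so device
    # memory is reserved in one large request up front instead of piecemeal
    # while the state dict is materialized.
    model = FluxTransformer2DModel.from_single_file(
        ckpt_url,
        torch_dtype=torch.bfloat16,
        device_map="cuda",
    )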
@@ -22,6 +22,7 @@ from huggingface_hub.utils import validate_hf_hub_args
 from typing_extensions import Self

 from .. import __version__
+from ..models.model_loading_utils import _caching_allocator_warmup, _determine_device_map, _expand_device_map
 from ..quantizers import DiffusersAutoQuantizer
 from ..utils import deprecate, is_accelerate_available, is_torch_version, logging
 from ..utils.torch_utils import empty_device_cache
@@ -297,6 +298,7 @@ class FromOriginalModelMixin:
         low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT)
         device = kwargs.pop("device", None)
         disable_mmap = kwargs.pop("disable_mmap", False)
+        device_map = kwargs.pop("device_map", None)

         user_agent = {"diffusers": __version__, "file_type": "single_file", "framework": "pytorch"}
         # In order to ensure popular quantization methods are supported. Can be disable with `disable_telemetry`
@@ -403,19 +405,8 @@
         with ctx():
             model = cls.from_config(diffusers_model_config)

-        checkpoint_mapping_kwargs = _get_mapping_function_kwargs(checkpoint_mapping_fn, **kwargs)
-        if _should_convert_state_dict_to_diffusers(model.state_dict(), checkpoint):
-            diffusers_format_checkpoint = checkpoint_mapping_fn(
-                config=diffusers_model_config, checkpoint=checkpoint, **checkpoint_mapping_kwargs
-            )
-        else:
-            diffusers_format_checkpoint = checkpoint
-
-        if not diffusers_format_checkpoint:
-            raise SingleFileComponentError(
-                f"Failed to load {mapping_class_name}. Weights for this component appear to be missing in the checkpoint."
-            )
+        model_state_dict = model.state_dict()

         # Check if `_keep_in_fp32_modules` is not None
         use_keep_in_fp32_modules = (cls._keep_in_fp32_modules is not None) and (
             (torch_dtype == torch.float16) or hasattr(hf_quantizer, "use_keep_in_fp32_modules")
@@ -428,6 +419,26 @@
         else:
             keep_in_fp32_modules = []

+        # Now that the model is loaded, we can determine the `device_map`
+        device_map = _determine_device_map(model, device_map, None, torch_dtype, keep_in_fp32_modules, hf_quantizer)
+        if device_map is not None:
+            expanded_device_map = _expand_device_map(device_map, model_state_dict.keys())
+            _caching_allocator_warmup(model, expanded_device_map, torch_dtype, hf_quantizer)
+
+        checkpoint_mapping_kwargs = _get_mapping_function_kwargs(checkpoint_mapping_fn, **kwargs)
+        if _should_convert_state_dict_to_diffusers(model_state_dict, checkpoint):
+            diffusers_format_checkpoint = checkpoint_mapping_fn(
+                config=diffusers_model_config, checkpoint=checkpoint, **checkpoint_mapping_kwargs
+            )
+        else:
+            diffusers_format_checkpoint = checkpoint
+
+        if not diffusers_format_checkpoint:
+            raise SingleFileComponentError(
+                f"Failed to load {mapping_class_name}. Weights for this component appear to be missing in the checkpoint."
+            )
+
         if hf_quantizer is not None:
             hf_quantizer.preprocess_model(
                 model=model,
...
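
For context, `_caching_allocator_warmup` exists so that PyTorch's caching allocator grows to the required size in one large allocation per target device, rather than through many small requests while individual tensors are copied in. A rough sketch of the idea, assuming a name-to-device `expanded_device_map` like the one built above; this is illustrative, not the library's exact implementation:

    import torch

    def warmup_caching_allocator(model, expanded_device_map, dtype):
        # Sum the bytes each accelerator device will receive.
        bytes_per_device = {}
        for name, param in model.named_parameters():
            device = torch.device(expanded_device_map.get(name, "cpu"))
            if device.type == "cpu":
                continue  # warmup only matters for accelerator memory
            bytes_per_device[device] = bytes_per_device.get(device, 0) + param.numel() * dtype.itemsize
        # Allocate (and immediately drop) one tensor of that size per device.
        # The caching allocator keeps the freed block, so the per-parameter
        # copies during loading are served from the cache instead of fresh
        # cudaMalloc calls.
        for device, nbytes in bytes_per_device.items():
            _ = torch.empty(nbytes, dtype=torch.uint8, device=device)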
@@ -69,3 +69,11 @@ class FluxTransformer2DModelSingleFileTests(unittest.TestCase):
         del model
         gc.collect()
         backend_empty_cache(torch_device)
+
+    def test_device_map_cuda(self):
+        backend_empty_cache(torch_device)
+        model = self.model_class.from_single_file(self.ckpt_path, device_map="cuda")
+
+        del model
+        gc.collect()
+        backend_empty_cache(torch_device)
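
One way to eyeball the effect locally, mirroring the new test (the checkpoint URL is again illustrative, and actual savings depend on checkpoint size and allocator state):

    import time
    import torch
    from diffusers import FluxTransformer2DModel

    ckpt_url = "https://huggingface.co/black-forest-labs/FLUX.1-dev/blob/main/flux1-dev.safetensors"  # illustrative

    torch.cuda.empty_cache()
    start = time.perf_counter()
    model = FluxTransformer2DModel.from_single_file(ckpt_url, torch_dtype=torch.bfloat16, device_map="cuda")
    print(f"load time: {time.perf_counter() - start:.1f}s")
    print(f"memory reserved: {torch.cuda.memory_reserved() / 2**30:.1f} GiB")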