Unverified commit d75ea3c7, authored by hlky, committed by GitHub

`device_map` in `load_model_dict_into_meta` (#10851)

* `device_map` in `load_model_dict_into_meta`

* _LOW_CPU_MEM_USAGE_DEFAULT

* fix `is_peft_version` and `is_bitsandbytes_version`
parent b27d4edb
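
Context for the diff below (an illustrative sketch, not part of the commit): `load_model_dict_into_meta` now receives an accelerate-style `device_map` rather than a single `device`, and passing `{"": device}` assigns every parameter of the module to that one device, which is exactly what the updated loaders do. A minimal, self-contained illustration of the new call shape, assuming a diffusers version that includes this change; the toy module, tensors, and dtype here are made up for the example, and the import path of this private helper is assumed:

```python
import torch
from torch import nn

# Private helper; import path assumed from the `from ..models.modeling_utils import ...` lines below.
from diffusers.models.modeling_utils import load_model_dict_into_meta

# Toy module created on the meta device, as the loaders do under low_cpu_mem_usage.
with torch.device("meta"):
    proj = nn.Linear(8, 8)

state_dict = {"weight": torch.randn(8, 8), "bias": torch.randn(8)}

# Old call shape (before this commit):
#     load_model_dict_into_meta(proj, state_dict, device="cpu", dtype=torch.float32)
# New call shape: an accelerate-style device_map; the "" key covers the whole module.
device_map = {"": "cpu"}
load_model_dict_into_meta(proj, state_dict, device_map=device_map, dtype=torch.float32)
```

The loaders changed in this commit build the map the same way, with `self.device` (or the device of the incoming tensors) in place of the literal device used above.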
@@ -17,7 +17,7 @@ from ..models.embeddings import (
     ImageProjection,
     MultiIPAdapterImageProjection,
 )
-from ..models.modeling_utils import load_model_dict_into_meta
+from ..models.modeling_utils import _LOW_CPU_MEM_USAGE_DEFAULT, load_model_dict_into_meta
 from ..utils import (
     is_accelerate_available,
     is_torch_version,
@@ -36,7 +36,7 @@ class FluxTransformer2DLoadersMixin:
     Load layers into a [`FluxTransformer2DModel`].
     """

-    def _convert_ip_adapter_image_proj_to_diffusers(self, state_dict, low_cpu_mem_usage=False):
+    def _convert_ip_adapter_image_proj_to_diffusers(self, state_dict, low_cpu_mem_usage=_LOW_CPU_MEM_USAGE_DEFAULT):
         if low_cpu_mem_usage:
             if is_accelerate_available():
                 from accelerate import init_empty_weights
@@ -82,11 +82,12 @@ class FluxTransformer2DLoadersMixin:
         if not low_cpu_mem_usage:
             image_projection.load_state_dict(updated_state_dict, strict=True)
         else:
-            load_model_dict_into_meta(image_projection, updated_state_dict, device=self.device, dtype=self.dtype)
+            device_map = {"": self.device}
+            load_model_dict_into_meta(image_projection, updated_state_dict, device_map=device_map, dtype=self.dtype)

         return image_projection

-    def _convert_ip_adapter_attn_to_diffusers(self, state_dicts, low_cpu_mem_usage=False):
+    def _convert_ip_adapter_attn_to_diffusers(self, state_dicts, low_cpu_mem_usage=_LOW_CPU_MEM_USAGE_DEFAULT):
         from ..models.attention_processor import (
             FluxIPAdapterJointAttnProcessor2_0,
         )
@@ -151,15 +152,15 @@ class FluxTransformer2DLoadersMixin:
             if not low_cpu_mem_usage:
                 attn_procs[name].load_state_dict(value_dict)
             else:
-                device = self.device
+                device_map = {"": self.device}
                 dtype = self.dtype
-                load_model_dict_into_meta(attn_procs[name], value_dict, device=device, dtype=dtype)
+                load_model_dict_into_meta(attn_procs[name], value_dict, device_map=device_map, dtype=dtype)

             key_id += 1

         return attn_procs

-    def _load_ip_adapter_weights(self, state_dicts, low_cpu_mem_usage=False):
+    def _load_ip_adapter_weights(self, state_dicts, low_cpu_mem_usage=_LOW_CPU_MEM_USAGE_DEFAULT):
         if not isinstance(state_dicts, list):
             state_dicts = [state_dicts]

...
@@ -75,8 +75,9 @@ class SD3Transformer2DLoadersMixin:
             if not low_cpu_mem_usage:
                 attn_procs[name].load_state_dict(layer_state_dict[idx], strict=True)
             else:
+                device_map = {"": self.device}
                 load_model_dict_into_meta(
-                    attn_procs[name], layer_state_dict[idx], device=self.device, dtype=self.dtype
+                    attn_procs[name], layer_state_dict[idx], device_map=device_map, dtype=self.dtype
                 )

         return attn_procs
@@ -144,7 +145,8 @@ class SD3Transformer2DLoadersMixin:
         if not low_cpu_mem_usage:
             image_proj.load_state_dict(updated_state_dict, strict=True)
         else:
-            load_model_dict_into_meta(image_proj, updated_state_dict, device=self.device, dtype=self.dtype)
+            device_map = {"": self.device}
+            load_model_dict_into_meta(image_proj, updated_state_dict, device_map=device_map, dtype=self.dtype)

         return image_proj

...
@@ -30,7 +30,7 @@ from ..models.embeddings import (
     IPAdapterPlusImageProjection,
     MultiIPAdapterImageProjection,
 )
-from ..models.modeling_utils import load_model_dict_into_meta, load_state_dict
+from ..models.modeling_utils import _LOW_CPU_MEM_USAGE_DEFAULT, load_model_dict_into_meta, load_state_dict
 from ..utils import (
     USE_PEFT_BACKEND,
     _get_model_file,
@@ -143,7 +143,7 @@ class UNet2DConditionLoadersMixin:
         adapter_name = kwargs.pop("adapter_name", None)
         _pipeline = kwargs.pop("_pipeline", None)
         network_alphas = kwargs.pop("network_alphas", None)
-        low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", False)
+        low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT)
         allow_pickle = False

         if low_cpu_mem_usage and is_peft_version("<=", "0.13.0"):
@@ -540,7 +540,7 @@ class UNet2DConditionLoadersMixin:

         return state_dict

-    def _convert_ip_adapter_image_proj_to_diffusers(self, state_dict, low_cpu_mem_usage=False):
+    def _convert_ip_adapter_image_proj_to_diffusers(self, state_dict, low_cpu_mem_usage=_LOW_CPU_MEM_USAGE_DEFAULT):
         if low_cpu_mem_usage:
             if is_accelerate_available():
                 from accelerate import init_empty_weights
@@ -753,11 +753,12 @@ class UNet2DConditionLoadersMixin:
         if not low_cpu_mem_usage:
             image_projection.load_state_dict(updated_state_dict, strict=True)
         else:
-            load_model_dict_into_meta(image_projection, updated_state_dict, device=self.device, dtype=self.dtype)
+            device_map = {"": self.device}
+            load_model_dict_into_meta(image_projection, updated_state_dict, device_map=device_map, dtype=self.dtype)

         return image_projection

-    def _convert_ip_adapter_attn_to_diffusers(self, state_dicts, low_cpu_mem_usage=False):
+    def _convert_ip_adapter_attn_to_diffusers(self, state_dicts, low_cpu_mem_usage=_LOW_CPU_MEM_USAGE_DEFAULT):
         from ..models.attention_processor import (
             IPAdapterAttnProcessor,
             IPAdapterAttnProcessor2_0,
@@ -846,13 +847,14 @@ class UNet2DConditionLoadersMixin:
             else:
                 device = next(iter(value_dict.values())).device
                 dtype = next(iter(value_dict.values())).dtype
-                load_model_dict_into_meta(attn_procs[name], value_dict, device=device, dtype=dtype)
+                device_map = {"": device}
+                load_model_dict_into_meta(attn_procs[name], value_dict, device_map=device_map, dtype=dtype)

             key_id += 2

         return attn_procs

-    def _load_ip_adapter_weights(self, state_dicts, low_cpu_mem_usage=False):
+    def _load_ip_adapter_weights(self, state_dicts, low_cpu_mem_usage=_LOW_CPU_MEM_USAGE_DEFAULT):
         if not isinstance(state_dicts, list):
             state_dicts = [state_dicts]

...
@@ -815,7 +815,7 @@ def is_peft_version(operation: str, version: str):
         version (`str`):
             A version string
     """
-    if not _peft_version:
+    if not _peft_available:
         return False
     return compare_versions(parse(_peft_version), operation, version)
@@ -829,7 +829,7 @@ def is_bitsandbytes_version(operation: str, version: str):
         version (`str`):
             A version string
     """
-    if not _bitsandbytes_version:
+    if not _bitsandbytes_available:
         return False
     return compare_versions(parse(_bitsandbytes_version), operation, version)

...
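
A note on the last two hunks (illustrative, not from the diff): the guards now test the `_peft_available` / `_bitsandbytes_available` flags instead of the version strings, matching the availability-first pattern of the other `is_*_version` helpers. A rough, self-contained sketch of that pattern; the initialization shown here and the simplified `compare_versions` are assumptions standing in for diffusers' actual `import_utils`, not copied from it:

```python
import importlib.metadata
import importlib.util
import operator

from packaging.version import Version, parse

# Assumed initialization: availability comes from find_spec, and the version
# string is only trusted when the package metadata can actually be resolved.
_peft_available = importlib.util.find_spec("peft") is not None
_peft_version = "N/A"
if _peft_available:
    try:
        _peft_version = importlib.metadata.version("peft")
    except importlib.metadata.PackageNotFoundError:
        _peft_available = False

# Simplified stand-in for diffusers' compare_versions helper.
_OPS = {"<": operator.lt, "<=": operator.le, "==": operator.eq, "!=": operator.ne, ">=": operator.ge, ">": operator.gt}


def compare_versions(library_version: Version, operation: str, requirement_version: str) -> bool:
    return _OPS[operation](library_version, parse(requirement_version))


def is_peft_version(operation: str, version: str) -> bool:
    # Guard on the availability flag, as in the fix above: with the placeholder
    # version string assumed here, `if not _peft_version` would never bail out
    # when peft is missing.
    if not _peft_available:
        return False
    return compare_versions(parse(_peft_version), operation, version)
```

Call sites are unchanged: a check like `is_peft_version("<=", "0.13.0")` in the UNet loader hunk above simply returns False when peft is not installed, instead of evaluating against a meaningless version string.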