Unverified commit b73c7383, authored by Aryan and committed by GitHub
Browse files

Remove device synchronization when loading weights (#11927)

* update

* make style
parent 06fd4277
...@@ -24,7 +24,7 @@ from typing_extensions import Self ...@@ -24,7 +24,7 @@ from typing_extensions import Self
from .. import __version__ from .. import __version__
from ..quantizers import DiffusersAutoQuantizer from ..quantizers import DiffusersAutoQuantizer
from ..utils import deprecate, is_accelerate_available, logging from ..utils import deprecate, is_accelerate_available, logging
from ..utils.torch_utils import device_synchronize, empty_device_cache from ..utils.torch_utils import empty_device_cache
from .single_file_utils import ( from .single_file_utils import (
SingleFileComponentError, SingleFileComponentError,
convert_animatediff_checkpoint_to_diffusers, convert_animatediff_checkpoint_to_diffusers,
...@@ -431,10 +431,7 @@ class FromOriginalModelMixin: ...@@ -431,10 +431,7 @@ class FromOriginalModelMixin:
keep_in_fp32_modules=keep_in_fp32_modules, keep_in_fp32_modules=keep_in_fp32_modules,
unexpected_keys=unexpected_keys, unexpected_keys=unexpected_keys,
) )
# Ensure tensors are correctly placed on device by synchronizing before returning control to user. This is
# required because we move tensors with non_blocking=True, which is slightly faster for model loading.
empty_device_cache() empty_device_cache()
device_synchronize()
else: else:
_, unexpected_keys = model.load_state_dict(diffusers_format_checkpoint, strict=False) _, unexpected_keys = model.load_state_dict(diffusers_format_checkpoint, strict=False)
......
...@@ -46,7 +46,7 @@ from ..utils import ( ...@@ -46,7 +46,7 @@ from ..utils import (
) )
from ..utils.constants import DIFFUSERS_REQUEST_TIMEOUT from ..utils.constants import DIFFUSERS_REQUEST_TIMEOUT
from ..utils.hub_utils import _get_model_file from ..utils.hub_utils import _get_model_file
from ..utils.torch_utils import device_synchronize, empty_device_cache from ..utils.torch_utils import empty_device_cache
if is_transformers_available(): if is_transformers_available():
...@@ -1690,10 +1690,7 @@ def create_diffusers_clip_model_from_ldm( ...@@ -1690,10 +1690,7 @@ def create_diffusers_clip_model_from_ldm(
if is_accelerate_available(): if is_accelerate_available():
load_model_dict_into_meta(model, diffusers_format_checkpoint, dtype=torch_dtype) load_model_dict_into_meta(model, diffusers_format_checkpoint, dtype=torch_dtype)
# Ensure tensors are correctly placed on device by synchronizing before returning control to user. This is
# required because we move tensors with non_blocking=True, which is slightly faster for model loading.
empty_device_cache() empty_device_cache()
device_synchronize()
else: else:
model.load_state_dict(diffusers_format_checkpoint, strict=False) model.load_state_dict(diffusers_format_checkpoint, strict=False)
...@@ -2153,10 +2150,7 @@ def create_diffusers_t5_model_from_checkpoint( ...@@ -2153,10 +2150,7 @@ def create_diffusers_t5_model_from_checkpoint(
if is_accelerate_available(): if is_accelerate_available():
load_model_dict_into_meta(model, diffusers_format_checkpoint, dtype=torch_dtype) load_model_dict_into_meta(model, diffusers_format_checkpoint, dtype=torch_dtype)
# Ensure tensors are correctly placed on device by synchronizing before returning control to user. This is
# required because we move tensors with non_blocking=True, which is slightly faster for model loading.
empty_device_cache() empty_device_cache()
device_synchronize()
else: else:
model.load_state_dict(diffusers_format_checkpoint) model.load_state_dict(diffusers_format_checkpoint)
......
...@@ -19,7 +19,7 @@ from ..models.embeddings import ( ...@@ -19,7 +19,7 @@ from ..models.embeddings import (
) )
from ..models.modeling_utils import _LOW_CPU_MEM_USAGE_DEFAULT, load_model_dict_into_meta from ..models.modeling_utils import _LOW_CPU_MEM_USAGE_DEFAULT, load_model_dict_into_meta
from ..utils import is_accelerate_available, is_torch_version, logging from ..utils import is_accelerate_available, is_torch_version, logging
from ..utils.torch_utils import device_synchronize, empty_device_cache from ..utils.torch_utils import empty_device_cache
if is_accelerate_available(): if is_accelerate_available():
...@@ -82,7 +82,6 @@ class FluxTransformer2DLoadersMixin: ...@@ -82,7 +82,6 @@ class FluxTransformer2DLoadersMixin:
device_map = {"": self.device} device_map = {"": self.device}
load_model_dict_into_meta(image_projection, updated_state_dict, device_map=device_map, dtype=self.dtype) load_model_dict_into_meta(image_projection, updated_state_dict, device_map=device_map, dtype=self.dtype)
empty_device_cache() empty_device_cache()
device_synchronize()
return image_projection return image_projection
...@@ -158,7 +157,6 @@ class FluxTransformer2DLoadersMixin: ...@@ -158,7 +157,6 @@ class FluxTransformer2DLoadersMixin:
key_id += 1 key_id += 1
empty_device_cache() empty_device_cache()
device_synchronize()
return attn_procs return attn_procs
......
...@@ -18,7 +18,7 @@ from ..models.attention_processor import SD3IPAdapterJointAttnProcessor2_0 ...@@ -18,7 +18,7 @@ from ..models.attention_processor import SD3IPAdapterJointAttnProcessor2_0
from ..models.embeddings import IPAdapterTimeImageProjection from ..models.embeddings import IPAdapterTimeImageProjection
from ..models.modeling_utils import _LOW_CPU_MEM_USAGE_DEFAULT, load_model_dict_into_meta from ..models.modeling_utils import _LOW_CPU_MEM_USAGE_DEFAULT, load_model_dict_into_meta
from ..utils import is_accelerate_available, is_torch_version, logging from ..utils import is_accelerate_available, is_torch_version, logging
from ..utils.torch_utils import device_synchronize, empty_device_cache from ..utils.torch_utils import empty_device_cache
logger = logging.get_logger(__name__) logger = logging.get_logger(__name__)
...@@ -82,7 +82,6 @@ class SD3Transformer2DLoadersMixin: ...@@ -82,7 +82,6 @@ class SD3Transformer2DLoadersMixin:
) )
empty_device_cache() empty_device_cache()
device_synchronize()
return attn_procs return attn_procs
...@@ -152,7 +151,6 @@ class SD3Transformer2DLoadersMixin: ...@@ -152,7 +151,6 @@ class SD3Transformer2DLoadersMixin:
device_map = {"": self.device} device_map = {"": self.device}
load_model_dict_into_meta(image_proj, updated_state_dict, device_map=device_map, dtype=self.dtype) load_model_dict_into_meta(image_proj, updated_state_dict, device_map=device_map, dtype=self.dtype)
empty_device_cache() empty_device_cache()
device_synchronize()
return image_proj return image_proj
......
...@@ -43,7 +43,7 @@ from ..utils import ( ...@@ -43,7 +43,7 @@ from ..utils import (
is_torch_version, is_torch_version,
logging, logging,
) )
from ..utils.torch_utils import device_synchronize, empty_device_cache from ..utils.torch_utils import empty_device_cache
from .lora_base import _func_optionally_disable_offloading from .lora_base import _func_optionally_disable_offloading
from .lora_pipeline import LORA_WEIGHT_NAME, LORA_WEIGHT_NAME_SAFE, TEXT_ENCODER_NAME, UNET_NAME from .lora_pipeline import LORA_WEIGHT_NAME, LORA_WEIGHT_NAME_SAFE, TEXT_ENCODER_NAME, UNET_NAME
from .utils import AttnProcsLayers from .utils import AttnProcsLayers
...@@ -755,7 +755,6 @@ class UNet2DConditionLoadersMixin: ...@@ -755,7 +755,6 @@ class UNet2DConditionLoadersMixin:
device_map = {"": self.device} device_map = {"": self.device}
load_model_dict_into_meta(image_projection, updated_state_dict, device_map=device_map, dtype=self.dtype) load_model_dict_into_meta(image_projection, updated_state_dict, device_map=device_map, dtype=self.dtype)
empty_device_cache() empty_device_cache()
device_synchronize()
return image_projection return image_projection
...@@ -854,7 +853,6 @@ class UNet2DConditionLoadersMixin: ...@@ -854,7 +853,6 @@ class UNet2DConditionLoadersMixin:
key_id += 2 key_id += 2
empty_device_cache() empty_device_cache()
device_synchronize()
return attn_procs return attn_procs
......
...@@ -62,7 +62,7 @@ from ..utils.hub_utils import ( ...@@ -62,7 +62,7 @@ from ..utils.hub_utils import (
load_or_create_model_card, load_or_create_model_card,
populate_model_card, populate_model_card,
) )
from ..utils.torch_utils import device_synchronize, empty_device_cache from ..utils.torch_utils import empty_device_cache
from .model_loading_utils import ( from .model_loading_utils import (
_caching_allocator_warmup, _caching_allocator_warmup,
_determine_device_map, _determine_device_map,
...@@ -1540,10 +1540,7 @@ class ModelMixin(torch.nn.Module, PushToHubMixin): ...@@ -1540,10 +1540,7 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
assign_to_params_buffers = check_support_param_buffer_assignment(model, state_dict) assign_to_params_buffers = check_support_param_buffer_assignment(model, state_dict)
error_msgs += _load_state_dict_into_model(model, state_dict, assign_to_params_buffers) error_msgs += _load_state_dict_into_model(model, state_dict, assign_to_params_buffers)
# Ensure tensors are correctly placed on device by synchronizing before returning control to user. This is
# required because we move tensors with non_blocking=True, which is slightly faster for model loading.
empty_device_cache() empty_device_cache()
device_synchronize()
if offload_index is not None and len(offload_index) > 0: if offload_index is not None and len(offload_index) > 0:
save_offload_index(offload_index, offload_folder) save_offload_index(offload_index, offload_folder)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment