Unverified commit b73c7383, authored by Aryan, committed by GitHub

Remove device synchronization when loading weights (#11927)

* update

* make style
parent 06fd4277
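
Background on the removed pattern: the loaders copy checkpoint tensors to the accelerator with non_blocking=True and previously followed the copies with empty_device_cache() plus device_synchronize(). The sketch below is illustrative only, not the diffusers implementation; copy_state_dict_to_device is a hypothetical helper, and only stock PyTorch APIs (Tensor.to(..., non_blocking=True), torch.cuda.empty_cache(), torch.cuda.synchronize()) are used. It shows why the trailing synchronize is generally redundant for host-to-device copies: work queued on the same CUDA stream executes in order, so kernels launched later already see the finished copies, and an explicit sync matters only when the host must read the tensors immediately.

import torch


def copy_state_dict_to_device(module: torch.nn.Module, state_dict: dict, device: str = "cuda") -> None:
    # Hypothetical helper, for illustration only (not part of diffusers).
    # Host-to-device copies issued with non_blocking=True are queued on the
    # current CUDA stream; kernels launched later on that stream are ordered
    # after them, so the model can be used without an explicit synchronize.
    with torch.no_grad():
        for name, param in module.named_parameters():
            # Queue an asynchronous copy; a pinned-memory source makes it truly non-blocking.
            param.data = state_dict[name].to(device, non_blocking=True)
    # Mirrors the empty_device_cache() call that the diff keeps.
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    # A torch.cuda.synchronize() here would only be needed if the host had to
    # read the copied tensors right away; device-side consumers do not need it.

Under that assumption, dropping device_synchronize() removes only a host-side stall; device-side correctness is preserved by stream ordering.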
@@ -24,7 +24,7 @@ from typing_extensions import Self
 from .. import __version__
 from ..quantizers import DiffusersAutoQuantizer
 from ..utils import deprecate, is_accelerate_available, logging
-from ..utils.torch_utils import device_synchronize, empty_device_cache
+from ..utils.torch_utils import empty_device_cache
 from .single_file_utils import (
     SingleFileComponentError,
     convert_animatediff_checkpoint_to_diffusers,
@@ -431,10 +431,7 @@ class FromOriginalModelMixin:
                 keep_in_fp32_modules=keep_in_fp32_modules,
                 unexpected_keys=unexpected_keys,
             )
-            # Ensure tensors are correctly placed on device by synchronizing before returning control to user. This is
-            # required because we move tensors with non_blocking=True, which is slightly faster for model loading.
             empty_device_cache()
-            device_synchronize()
         else:
             _, unexpected_keys = model.load_state_dict(diffusers_format_checkpoint, strict=False)
......
@@ -46,7 +46,7 @@ from ..utils import (
 )
 from ..utils.constants import DIFFUSERS_REQUEST_TIMEOUT
 from ..utils.hub_utils import _get_model_file
-from ..utils.torch_utils import device_synchronize, empty_device_cache
+from ..utils.torch_utils import empty_device_cache

 if is_transformers_available():
@@ -1690,10 +1690,7 @@ def create_diffusers_clip_model_from_ldm(
     if is_accelerate_available():
         load_model_dict_into_meta(model, diffusers_format_checkpoint, dtype=torch_dtype)
-        # Ensure tensors are correctly placed on device by synchronizing before returning control to user. This is
-        # required because we move tensors with non_blocking=True, which is slightly faster for model loading.
         empty_device_cache()
-        device_synchronize()
     else:
         model.load_state_dict(diffusers_format_checkpoint, strict=False)
@@ -2153,10 +2150,7 @@ def create_diffusers_t5_model_from_checkpoint(
     if is_accelerate_available():
         load_model_dict_into_meta(model, diffusers_format_checkpoint, dtype=torch_dtype)
-        # Ensure tensors are correctly placed on device by synchronizing before returning control to user. This is
-        # required because we move tensors with non_blocking=True, which is slightly faster for model loading.
         empty_device_cache()
-        device_synchronize()
     else:
         model.load_state_dict(diffusers_format_checkpoint)
......
@@ -19,7 +19,7 @@ from ..models.embeddings import (
 )
 from ..models.modeling_utils import _LOW_CPU_MEM_USAGE_DEFAULT, load_model_dict_into_meta
 from ..utils import is_accelerate_available, is_torch_version, logging
-from ..utils.torch_utils import device_synchronize, empty_device_cache
+from ..utils.torch_utils import empty_device_cache

 if is_accelerate_available():
@@ -82,7 +82,6 @@ class FluxTransformer2DLoadersMixin:
             device_map = {"": self.device}
             load_model_dict_into_meta(image_projection, updated_state_dict, device_map=device_map, dtype=self.dtype)
             empty_device_cache()
-            device_synchronize()

         return image_projection
@@ -158,7 +157,6 @@ class FluxTransformer2DLoadersMixin:
             key_id += 1

         empty_device_cache()
-        device_synchronize()

         return attn_procs
......
@@ -18,7 +18,7 @@ from ..models.attention_processor import SD3IPAdapterJointAttnProcessor2_0
 from ..models.embeddings import IPAdapterTimeImageProjection
 from ..models.modeling_utils import _LOW_CPU_MEM_USAGE_DEFAULT, load_model_dict_into_meta
 from ..utils import is_accelerate_available, is_torch_version, logging
-from ..utils.torch_utils import device_synchronize, empty_device_cache
+from ..utils.torch_utils import empty_device_cache

 logger = logging.get_logger(__name__)
@@ -82,7 +82,6 @@ class SD3Transformer2DLoadersMixin:
                 )

         empty_device_cache()
-        device_synchronize()

         return attn_procs
@@ -152,7 +151,6 @@ class SD3Transformer2DLoadersMixin:
             device_map = {"": self.device}
             load_model_dict_into_meta(image_proj, updated_state_dict, device_map=device_map, dtype=self.dtype)
             empty_device_cache()
-            device_synchronize()

         return image_proj
......
@@ -43,7 +43,7 @@ from ..utils import (
     is_torch_version,
     logging,
 )
-from ..utils.torch_utils import device_synchronize, empty_device_cache
+from ..utils.torch_utils import empty_device_cache
 from .lora_base import _func_optionally_disable_offloading
 from .lora_pipeline import LORA_WEIGHT_NAME, LORA_WEIGHT_NAME_SAFE, TEXT_ENCODER_NAME, UNET_NAME
 from .utils import AttnProcsLayers
@@ -755,7 +755,6 @@ class UNet2DConditionLoadersMixin:
             device_map = {"": self.device}
             load_model_dict_into_meta(image_projection, updated_state_dict, device_map=device_map, dtype=self.dtype)
             empty_device_cache()
-            device_synchronize()

         return image_projection
@@ -854,7 +853,6 @@ class UNet2DConditionLoadersMixin:
             key_id += 2

         empty_device_cache()
-        device_synchronize()

         return attn_procs
......
@@ -62,7 +62,7 @@ from ..utils.hub_utils import (
     load_or_create_model_card,
     populate_model_card,
 )
-from ..utils.torch_utils import device_synchronize, empty_device_cache
+from ..utils.torch_utils import empty_device_cache
 from .model_loading_utils import (
     _caching_allocator_warmup,
     _determine_device_map,
@@ -1540,10 +1540,7 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
                 assign_to_params_buffers = check_support_param_buffer_assignment(model, state_dict)
                 error_msgs += _load_state_dict_into_model(model, state_dict, assign_to_params_buffers)
-            # Ensure tensors are correctly placed on device by synchronizing before returning control to user. This is
-            # required because we move tensors with non_blocking=True, which is slightly faster for model loading.
             empty_device_cache()
-            device_synchronize()

             if offload_index is not None and len(offload_index) > 0:
                 save_offload_index(offload_index, offload_folder)
......
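
None of the public loading entry points change; the diff only drops the internal post-load device_synchronize() call while keeping empty_device_cache(). As a hedged usage sketch of the single-file path touched above (the checkpoint path is a placeholder), loading still looks like this:

import torch
from diffusers import StableDiffusionXLPipeline  # any pipeline with single-file support

# Placeholder checkpoint path; substitute a real single-file .safetensors checkpoint.
pipe = StableDiffusionXLPipeline.from_single_file(
    "path/to/checkpoint.safetensors",
    torch_dtype=torch.float16,
)
pipe.to("cuda")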