Unverified Commit 37cb819d authored by Patrick von Platen, committed by GitHub

[Lora] Speed up lora loading (#4994)

* speed up lora loading

* Apply suggestions from code review

* up

* up

* Fix more

* Correct more

* Apply suggestions from code review

* up

* Fix more

* Fix more -

* up

* up
parent f64d52db
@@ -128,6 +128,31 @@ def load_state_dict(checkpoint_file: Union[str, os.PathLike], variant: Optional[
     )
 
 
+def load_model_dict_into_meta(model, state_dict, device=None, dtype=None, model_name_or_path=None):
+    device = device or torch.device("cpu")
+    dtype = dtype or torch.float32
+
+    unexpected_keys = []
+    empty_state_dict = model.state_dict()
+    for param_name, param in state_dict.items():
+        if param_name not in empty_state_dict:
+            unexpected_keys.append(param_name)
+            continue
+        if empty_state_dict[param_name].shape != param.shape:
+            model_name_or_path_str = f"{model_name_or_path} " if model_name_or_path is not None else ""
+            raise ValueError(
+                f"Cannot load {model_name_or_path_str}because {param_name} expected shape {empty_state_dict[param_name]}, but got {param.shape}. If you want to instead overwrite randomly initialized weights, please make sure to pass both `low_cpu_mem_usage=False` and `ignore_mismatched_sizes=True`. For more information, see also: https://github.com/huggingface/diffusers/issues/1619#issuecomment-1345604389 as an example."
+            )
+
+        accepts_dtype = "dtype" in set(inspect.signature(set_module_tensor_to_device).parameters.keys())
+        if accepts_dtype:
+            set_module_tensor_to_device(model, param_name, device, value=param, dtype=dtype)
+        else:
+            set_module_tensor_to_device(model, param_name, device, value=param)
+
+    return unexpected_keys
+
+
 def _load_state_dict_into_model(model_to_load, state_dict):
     # Convert old format to new format if needed from a PyTorch state_dict
     # copy state_dict so _load_from_state_dict can modify it
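
The helper added above is the core of the fast low_cpu_mem_usage path: the model skeleton holds meta-device parameters (no real allocations), and each checkpoint tensor is materialized in place via accelerate's set_module_tensor_to_device. A minimal usage sketch, assuming a diffusers release where the helper is importable from diffusers.models.modeling_utils; the TinyNet module and its shapes are illustrative, not from this commit:

import torch
from accelerate import init_empty_weights
from diffusers.models.modeling_utils import load_model_dict_into_meta


class TinyNet(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.proj = torch.nn.Linear(4, 4)


# Weights that would normally come from a checkpoint file on disk.
state_dict = TinyNet().state_dict()

# Under init_empty_weights, parameters are created on the meta device, so no
# time or memory is spent on a random init the checkpoint would overwrite.
with init_empty_weights():
    model = TinyNet()

# Materialize the checkpoint tensors directly into the meta-device skeleton.
unexpected = load_model_dict_into_meta(model, state_dict, device="cpu", dtype=torch.float32)
assert unexpected == []  # every checkpoint key matched a model parameter
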
@@ -624,29 +649,14 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
                             " `low_cpu_mem_usage=False` and `device_map=None` if you want to randomly initialize"
                             " those weights or else make sure your checkpoint file is correct."
                         )
-                    unexpected_keys = []
-                    empty_state_dict = model.state_dict()
-                    for param_name, param in state_dict.items():
-                        accepts_dtype = "dtype" in set(
-                            inspect.signature(set_module_tensor_to_device).parameters.keys()
-                        )
-                        if param_name not in empty_state_dict:
-                            unexpected_keys.append(param_name)
-                            continue
-                        if empty_state_dict[param_name].shape != param.shape:
-                            raise ValueError(
-                                f"Cannot load {pretrained_model_name_or_path} because {param_name} expected shape {empty_state_dict[param_name]}, but got {param.shape}. If you want to instead overwrite randomly initialized weights, please make sure to pass both `low_cpu_mem_usage=False` and `ignore_mismatched_sizes=True`. For more information, see also: https://github.com/huggingface/diffusers/issues/1619#issuecomment-1345604389 as an example."
-                            )
-                        if accepts_dtype:
-                            set_module_tensor_to_device(
-                                model, param_name, param_device, value=param, dtype=torch_dtype
-                            )
-                        else:
-                            set_module_tensor_to_device(model, param_name, param_device, value=param)
+
+                    unexpected_keys = load_model_dict_into_meta(
+                        model,
+                        state_dict,
+                        device=param_device,
+                        dtype=torch_dtype,
+                        model_name_or_path=pretrained_model_name_or_path,
+                    )
 
                 if cls._keys_to_ignore_on_load_unexpected is not None:
                     for pat in cls._keys_to_ignore_on_load_unexpected:
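
At the call site, the previously inlined per-parameter loop collapses into a single call to the new shared helper. The accepts_dtype probe kept inside the helper guards against accelerate releases whose set_module_tensor_to_device predates the dtype keyword; a self-contained sketch of that feature-detection pattern, with hypothetical function names standing in for the two API variants:

import inspect


def new_api(value, device, dtype=None):  # hypothetical: newer releases accept dtype
    return value


def old_api(value, device):  # hypothetical: older releases do not
    return value


def call_compat(fn, value, device, dtype):
    # Probe the callable's signature once, then forward `dtype` only
    # when the installed version actually supports it.
    if "dtype" in inspect.signature(fn).parameters:
        return fn(value, device, dtype=dtype)
    return fn(value, device)


# Both calls succeed regardless of which API variant is installed.
call_compat(new_api, 1.0, "cpu", "float32")
call_compat(old_api, 1.0, "cpu", "float32")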