"docs/vscode:/vscode.git/clone" did not exist on "2d846263d626b349772bc53315ae7deac1ece844"
Unverified Commit 6ddbf622 authored by Sayak Paul's avatar Sayak Paul Committed by GitHub
Browse files

[Transformer2DModel] Handle `norm_type` safely while remapping (#8370)



* handle norm_type of transformer2d_model safely.

* log an info when old model class is being returned.

* Apply suggestions from code review
Co-authored-by: default avatarDhruv Nair <dhruv.nair@gmail.com>

* remove extra stuff

---------
Co-authored-by: default avatarDhruv Nair <dhruv.nair@gmail.com>
parent 3ff39e8e
......@@ -71,18 +71,22 @@ def _determine_device_map(model: torch.nn.Module, device_map, max_memory, torch_
def _fetch_remapped_cls_from_config(config, old_class):
previous_class_name = old_class.__name__
remapped_class_name = _CLASS_REMAPPING_DICT.get(previous_class_name).get(config["norm_type"])
# load diffusers library to import compatible and original scheduler
diffusers_library = importlib.import_module(__name__.split(".")[0])
remapped_class = getattr(diffusers_library, remapped_class_name)
logger.info(
f"Changing class object to be of `{remapped_class_name}` type from `{previous_class_name}` type."
"This is because `previous_class_name` is scheduled to be deprecated in a future version. Note that this"
" DOESN'T affect the final results."
)
return remapped_class
remapped_class_name = _CLASS_REMAPPING_DICT.get(previous_class_name).get(config["norm_type"], None)
# Details:
# https://github.com/huggingface/diffusers/pull/7647#discussion_r1621344818
if remapped_class_name:
# load diffusers library to import compatible and original scheduler
diffusers_library = importlib.import_module(__name__.split(".")[0])
remapped_class = getattr(diffusers_library, remapped_class_name)
logger.info(
f"Changing class object to be of `{remapped_class_name}` type from `{previous_class_name}` type."
f"This is because `{previous_class_name}` is scheduled to be deprecated in a future version. Note that this"
" DOESN'T affect the final results."
)
return remapped_class
else:
return old_class
def load_state_dict(checkpoint_file: Union[str, os.PathLike], variant: Optional[str] = None):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment