"src/git@developer.sourcefind.cn:renzhc/diffusers_dcu.git" did not exist on "6b09f370c4184a89276c6891d17f45b9c8e8b4e5"
Unverified commit cee1cd6e, authored by Patrick von Platen, committed by GitHub
Browse files

[Remote code] Add functionality to run remote models, schedulers, pipelines (#5472)



* upload custom remote poc

* up

* make style

* finish

* better name

* Apply suggestions from code review

* Update tests/pipelines/test_pipelines.py

* more fixes

* remove ipdb

* more fixes

* fix more

* finish tests

---------
Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
parent 5b448a5e
...@@ -485,10 +485,18 @@ class ConfigMixin: ...@@ -485,10 +485,18 @@ class ConfigMixin:
# remove attributes from orig class that cannot be expected # remove attributes from orig class that cannot be expected
orig_cls_name = config_dict.pop("_class_name", cls.__name__) orig_cls_name = config_dict.pop("_class_name", cls.__name__)
if orig_cls_name != cls.__name__ and hasattr(diffusers_library, orig_cls_name): if (
isinstance(orig_cls_name, str)
and orig_cls_name != cls.__name__
and hasattr(diffusers_library, orig_cls_name)
):
orig_cls = getattr(diffusers_library, orig_cls_name) orig_cls = getattr(diffusers_library, orig_cls_name)
unexpected_keys_from_orig = cls._get_init_keys(orig_cls) - expected_keys unexpected_keys_from_orig = cls._get_init_keys(orig_cls) - expected_keys
config_dict = {k: v for k, v in config_dict.items() if k not in unexpected_keys_from_orig} config_dict = {k: v for k, v in config_dict.items() if k not in unexpected_keys_from_orig}
elif not isinstance(orig_cls_name, str) and not isinstance(orig_cls_name, (list, tuple)):
raise ValueError(
"Make sure that the `_class_name` is of type string or list of string (for custom pipelines)."
)
# remove private attributes # remove private attributes
config_dict = {k: v for k, v in config_dict.items() if not k.startswith("_")} config_dict = {k: v for k, v in config_dict.items() if not k.startswith("_")}
......
...@@ -33,8 +33,6 @@ from packaging import version ...@@ -33,8 +33,6 @@ from packaging import version
from requests.exceptions import HTTPError from requests.exceptions import HTTPError
from tqdm.auto import tqdm from tqdm.auto import tqdm
import diffusers
from .. import __version__ from .. import __version__
from ..configuration_utils import ConfigMixin from ..configuration_utils import ConfigMixin
from ..models.modeling_utils import _LOW_CPU_MEM_USAGE_DEFAULT from ..models.modeling_utils import _LOW_CPU_MEM_USAGE_DEFAULT
...@@ -305,13 +303,23 @@ def maybe_raise_or_warn( ...@@ -305,13 +303,23 @@ def maybe_raise_or_warn(
) )
def get_class_obj_and_candidates(library_name, class_name, importable_classes, pipelines, is_pipeline_module): def get_class_obj_and_candidates(
library_name, class_name, importable_classes, pipelines, is_pipeline_module, component_name=None, cache_dir=None
):
"""Simple helper method to retrieve class object of module as well as potential parent class objects""" """Simple helper method to retrieve class object of module as well as potential parent class objects"""
component_folder = os.path.join(cache_dir, component_name)
if is_pipeline_module: if is_pipeline_module:
pipeline_module = getattr(pipelines, library_name) pipeline_module = getattr(pipelines, library_name)
class_obj = getattr(pipeline_module, class_name) class_obj = getattr(pipeline_module, class_name)
class_candidates = {c: class_obj for c in importable_classes.keys()} class_candidates = {c: class_obj for c in importable_classes.keys()}
elif os.path.isfile(os.path.join(component_folder, library_name + ".py")):
# load custom component
class_obj = get_class_from_dynamic_module(
component_folder, module_file=library_name + ".py", class_name=class_name
)
class_candidates = {c: class_obj for c in importable_classes.keys()}
else: else:
# else we just import it from the library. # else we just import it from the library.
library = importlib.import_module(library_name) library = importlib.import_module(library_name)
...@@ -323,7 +331,15 @@ def get_class_obj_and_candidates(library_name, class_name, importable_classes, p ...@@ -323,7 +331,15 @@ def get_class_obj_and_candidates(library_name, class_name, importable_classes, p
def _get_pipeline_class( def _get_pipeline_class(
class_obj, config, load_connected_pipeline=False, custom_pipeline=None, cache_dir=None, revision=None class_obj,
config,
load_connected_pipeline=False,
custom_pipeline=None,
repo_id=None,
hub_revision=None,
class_name=None,
cache_dir=None,
revision=None,
): ):
if custom_pipeline is not None: if custom_pipeline is not None:
if custom_pipeline.endswith(".py"): if custom_pipeline.endswith(".py"):
...@@ -331,11 +347,19 @@ def _get_pipeline_class( ...@@ -331,11 +347,19 @@ def _get_pipeline_class(
# decompose into folder & file # decompose into folder & file
file_name = path.name file_name = path.name
custom_pipeline = path.parent.absolute() custom_pipeline = path.parent.absolute()
elif repo_id is not None:
file_name = f"{custom_pipeline}.py"
custom_pipeline = repo_id
else: else:
file_name = CUSTOM_PIPELINE_FILE_NAME file_name = CUSTOM_PIPELINE_FILE_NAME
return get_class_from_dynamic_module( return get_class_from_dynamic_module(
custom_pipeline, module_file=file_name, cache_dir=cache_dir, revision=revision custom_pipeline,
module_file=file_name,
class_name=class_name,
repo_id=repo_id,
cache_dir=cache_dir,
revision=revision if hub_revision is None else hub_revision,
) )
if class_obj != DiffusionPipeline: if class_obj != DiffusionPipeline:
...@@ -383,11 +407,18 @@ def load_sub_model( ...@@ -383,11 +407,18 @@ def load_sub_model(
variant: str, variant: str,
low_cpu_mem_usage: bool, low_cpu_mem_usage: bool,
cached_folder: Union[str, os.PathLike], cached_folder: Union[str, os.PathLike],
revision: str = None,
): ):
"""Helper method to load the module `name` from `library_name` and `class_name`""" """Helper method to load the module `name` from `library_name` and `class_name`"""
# retrieve class candidates # retrieve class candidates
class_obj, class_candidates = get_class_obj_and_candidates( class_obj, class_candidates = get_class_obj_and_candidates(
library_name, class_name, importable_classes, pipelines, is_pipeline_module library_name,
class_name,
importable_classes,
pipelines,
is_pipeline_module,
component_name=name,
cache_dir=cached_folder,
) )
load_method_name = None load_method_name = None
...@@ -414,14 +445,15 @@ def load_sub_model( ...@@ -414,14 +445,15 @@ def load_sub_model(
load_method = getattr(class_obj, load_method_name) load_method = getattr(class_obj, load_method_name)
# add kwargs to loading method # add kwargs to loading method
diffusers_module = importlib.import_module(__name__.split(".")[0])
loading_kwargs = {} loading_kwargs = {}
if issubclass(class_obj, torch.nn.Module): if issubclass(class_obj, torch.nn.Module):
loading_kwargs["torch_dtype"] = torch_dtype loading_kwargs["torch_dtype"] = torch_dtype
if issubclass(class_obj, diffusers.OnnxRuntimeModel): if issubclass(class_obj, diffusers_module.OnnxRuntimeModel):
loading_kwargs["provider"] = provider loading_kwargs["provider"] = provider
loading_kwargs["sess_options"] = sess_options loading_kwargs["sess_options"] = sess_options
is_diffusers_model = issubclass(class_obj, diffusers.ModelMixin) is_diffusers_model = issubclass(class_obj, diffusers_module.ModelMixin)
if is_transformers_available(): if is_transformers_available():
transformers_version = version.parse(version.parse(transformers.__version__).base_version) transformers_version = version.parse(version.parse(transformers.__version__).base_version)
...@@ -501,7 +533,8 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin): ...@@ -501,7 +533,8 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
def register_modules(self, **kwargs): def register_modules(self, **kwargs):
# import it here to avoid circular import # import it here to avoid circular import
from diffusers import pipelines diffusers_module = importlib.import_module(__name__.split(".")[0])
pipelines = getattr(diffusers_module, "pipelines")
for name, module in kwargs.items(): for name, module in kwargs.items():
# retrieve library # retrieve library
...@@ -1080,11 +1113,21 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin): ...@@ -1080,11 +1113,21 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
# 3. Load the pipeline class, if using custom module then load it from the hub # 3. Load the pipeline class, if using custom module then load it from the hub
# if we load from explicit class, let's use it # if we load from explicit class, let's use it
custom_class_name = None
if os.path.isfile(os.path.join(cached_folder, f"{custom_pipeline}.py")):
custom_pipeline = os.path.join(cached_folder, f"{custom_pipeline}.py")
elif isinstance(config_dict["_class_name"], (list, tuple)) and os.path.isfile(
os.path.join(cached_folder, f"{config_dict['_class_name'][0]}.py")
):
custom_pipeline = os.path.join(cached_folder, f"{config_dict['_class_name'][0]}.py")
custom_class_name = config_dict["_class_name"][1]
pipeline_class = _get_pipeline_class( pipeline_class = _get_pipeline_class(
cls, cls,
config_dict, config_dict,
load_connected_pipeline=load_connected_pipeline, load_connected_pipeline=load_connected_pipeline,
custom_pipeline=custom_pipeline, custom_pipeline=custom_pipeline,
class_name=custom_class_name,
cache_dir=cache_dir, cache_dir=cache_dir,
revision=custom_revision, revision=custom_revision,
) )
...@@ -1223,6 +1266,7 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin): ...@@ -1223,6 +1266,7 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
variant=variant, variant=variant,
low_cpu_mem_usage=low_cpu_mem_usage, low_cpu_mem_usage=low_cpu_mem_usage,
cached_folder=cached_folder, cached_folder=cached_folder,
revision=revision,
) )
logger.info( logger.info(
f"Loaded {name} as {class_name} from `{name}` subfolder of {pretrained_model_name_or_path}." f"Loaded {name} as {class_name} from `{name}` subfolder of {pretrained_model_name_or_path}."
...@@ -1542,6 +1586,10 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin): ...@@ -1542,6 +1586,10 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
will never be downloaded. By default `use_onnx` defaults to the `_is_onnx` class attribute which is will never be downloaded. By default `use_onnx` defaults to the `_is_onnx` class attribute which is
`False` for non-ONNX pipelines and `True` for ONNX pipelines. ONNX weights include both files ending `False` for non-ONNX pipelines and `True` for ONNX pipelines. ONNX weights include both files ending
with `.onnx` and `.pb`. with `.onnx` and `.pb`.
trust_remote_code (`bool`, *optional*, defaults to `False`):
Whether or not to allow for custom pipelines and components defined on the Hub in their own files. This
option should only be set to `True` for repositories you trust and in which you have read the code, as
it will execute code present on the Hub on your local machine.
Returns: Returns:
`os.PathLike`: `os.PathLike`:
...@@ -1569,6 +1617,7 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin): ...@@ -1569,6 +1617,7 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
use_safetensors = kwargs.pop("use_safetensors", None) use_safetensors = kwargs.pop("use_safetensors", None)
use_onnx = kwargs.pop("use_onnx", None) use_onnx = kwargs.pop("use_onnx", None)
load_connected_pipeline = kwargs.pop("load_connected_pipeline", False) load_connected_pipeline = kwargs.pop("load_connected_pipeline", False)
trust_remote_code = kwargs.pop("trust_remote_code", False)
allow_pickle = False allow_pickle = False
if use_safetensors is None: if use_safetensors is None:
...@@ -1604,15 +1653,34 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin): ...@@ -1604,15 +1653,34 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
) )
config_dict = cls._dict_from_json_file(config_file) config_dict = cls._dict_from_json_file(config_file)
ignore_filenames = config_dict.pop("_ignore_files", []) ignore_filenames = config_dict.pop("_ignore_files", [])
# retrieve all folder_names that contain relevant files # retrieve all folder_names that contain relevant files
folder_names = [k for k, v in config_dict.items() if isinstance(v, list)] folder_names = [k for k, v in config_dict.items() if isinstance(v, list) and k != "_class_name"]
filenames = {sibling.rfilename for sibling in info.siblings} filenames = {sibling.rfilename for sibling in info.siblings}
model_filenames, variant_filenames = variant_compatible_siblings(filenames, variant=variant) model_filenames, variant_filenames = variant_compatible_siblings(filenames, variant=variant)
diffusers_module = importlib.import_module(__name__.split(".")[0])
pipelines = getattr(diffusers_module, "pipelines")
# optionally create a custom component <> custom file mapping
custom_components = {}
for component in folder_names:
module_candidate = config_dict[component][0]
if module_candidate is None:
continue
candidate_file = os.path.join(component, module_candidate + ".py")
if candidate_file in filenames:
custom_components[component] = module_candidate
elif module_candidate not in LOADABLE_CLASSES and not hasattr(pipelines, module_candidate):
raise ValueError(
f"{candidate_file} as defined in `model_index.json` does not exist in {pretrained_model_name} and is not a module in 'diffusers/pipelines'."
)
if len(variant_filenames) == 0 and variant is not None: if len(variant_filenames) == 0 and variant is not None:
deprecation_message = ( deprecation_message = (
f"You are trying to load the model files of the `variant={variant}`, but no such modeling files are available." f"You are trying to load the model files of the `variant={variant}`, but no such modeling files are available."
...@@ -1636,12 +1704,21 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin): ...@@ -1636,12 +1704,21 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
model_folder_names = {os.path.split(f)[0] for f in model_filenames if os.path.split(f)[0] in folder_names} model_folder_names = {os.path.split(f)[0] for f in model_filenames if os.path.split(f)[0] in folder_names}
custom_class_name = None
if custom_pipeline is None and isinstance(config_dict["_class_name"], (list, tuple)):
custom_pipeline = config_dict["_class_name"][0]
custom_class_name = config_dict["_class_name"][1]
# all filenames compatible with variant will be added # all filenames compatible with variant will be added
allow_patterns = list(model_filenames) allow_patterns = list(model_filenames)
# allow all patterns from non-model folders # allow all patterns from non-model folders
# this enables downloading schedulers, tokenizers, ... # this enables downloading schedulers, tokenizers, ...
allow_patterns += [f"{k}/*" for k in folder_names if k not in model_folder_names] allow_patterns += [f"{k}/*" for k in folder_names if k not in model_folder_names]
# add custom component files
allow_patterns += [f"{k}/{f}.py" for k, f in custom_components.items()]
# add custom pipeline file
allow_patterns += [f"{custom_pipeline}.py"] if f"{custom_pipeline}.py" in filenames else []
# also allow downloading config.json files with the model # also allow downloading config.json files with the model
allow_patterns += [os.path.join(k, "config.json") for k in model_folder_names] allow_patterns += [os.path.join(k, "config.json") for k in model_folder_names]
...@@ -1652,12 +1729,32 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin): ...@@ -1652,12 +1729,32 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
CUSTOM_PIPELINE_FILE_NAME, CUSTOM_PIPELINE_FILE_NAME,
] ]
load_pipe_from_hub = custom_pipeline is not None and f"{custom_pipeline}.py" in filenames
load_components_from_hub = len(custom_components) > 0
if load_pipe_from_hub and not trust_remote_code:
raise ValueError(
f"The repository for {pretrained_model_name} contains custom code in {custom_pipeline}.py which must be executed to correctly "
f"load the model. You can inspect the repository content at https://hf.co/{pretrained_model_name}/blob/main/{custom_pipeline}.py.\n"
f"Please pass the argument `trust_remote_code=True` to allow custom code to be run."
)
if load_components_from_hub and not trust_remote_code:
raise ValueError(
f"The repository for {pretrained_model_name} contains custom code in {'.py, '.join([os.path.join(k, v) for k,v in custom_components.items()])} which must be executed to correctly "
f"load the model. You can inspect the repository content at {', '.join([f'https://hf.co/{pretrained_model_name}/{k}/{v}.py' for k,v in custom_components.items()])}.\n"
f"Please pass the argument `trust_remote_code=True` to allow custom code to be run."
)
# retrieve passed components that should not be downloaded # retrieve passed components that should not be downloaded
pipeline_class = _get_pipeline_class( pipeline_class = _get_pipeline_class(
cls, cls,
config_dict, config_dict,
load_connected_pipeline=load_connected_pipeline, load_connected_pipeline=load_connected_pipeline,
custom_pipeline=custom_pipeline, custom_pipeline=custom_pipeline,
repo_id=pretrained_model_name if load_pipe_from_hub else None,
hub_revision=revision,
class_name=custom_class_name,
cache_dir=cache_dir, cache_dir=cache_dir,
revision=custom_revision, revision=custom_revision,
) )
...@@ -1754,9 +1851,10 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin): ...@@ -1754,9 +1851,10 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
# retrieve pipeline class from local file # retrieve pipeline class from local file
cls_name = cls.load_config(os.path.join(cached_folder, "model_index.json")).get("_class_name", None) cls_name = cls.load_config(os.path.join(cached_folder, "model_index.json")).get("_class_name", None)
cls_name = cls_name[4:] if cls_name.startswith("Flax") else cls_name cls_name = cls_name[4:] if isinstance(cls_name, str) and cls_name.startswith("Flax") else cls_name
pipeline_class = getattr(diffusers, cls_name, None) diffusers_module = importlib.import_module(__name__.split(".")[0])
pipeline_class = getattr(diffusers_module, cls_name, None) if isinstance(cls_name, str) else None
if pipeline_class is not None and pipeline_class._load_connected_pipes: if pipeline_class is not None and pipeline_class._load_connected_pipes:
modelcard = ModelCard.load(os.path.join(cached_folder, "README.md")) modelcard = ModelCard.load(os.path.join(cached_folder, "README.md"))
......
...@@ -862,6 +862,58 @@ class CustomPipelineTests(unittest.TestCase): ...@@ -862,6 +862,58 @@ class CustomPipelineTests(unittest.TestCase):
# compare output to https://huggingface.co/hf-internal-testing/diffusers-dummy-pipeline/blob/main/pipeline.py#L102 # compare output to https://huggingface.co/hf-internal-testing/diffusers-dummy-pipeline/blob/main/pipeline.py#L102
assert output_str == "This is a test" assert output_str == "This is a test"
def test_remote_components(self):
    """Load a Hub repo with custom components, without and with an explicit custom pipeline.

    Loading must raise `ValueError` unless `trust_remote_code=True` is passed. With it,
    the custom UNet and scheduler are recorded in the pipeline config and a short
    two-step inference run yields a `(1, 64, 64, 3)` array.
    """
    repo_id = "hf-internal-testing/tiny-sdxl-custom-components"

    # loading has to be refused as long as `trust_remote_code` is not passed
    with self.assertRaises(ValueError):
        DiffusionPipeline.from_pretrained(repo_id)

    # (extra `from_pretrained` kwargs, expected pipeline class name)
    #   1. only the custom components "my_unet", "my_scheduler" are loaded
    #   2. the custom components plus the explicitly requested custom pipeline
    scenarios = (
        ({}, "StableDiffusionXLPipeline"),
        ({"custom_pipeline": "my_pipeline"}, "MyPipeline"),
    )

    for extra_kwargs, expected_cls_name in scenarios:
        pipeline = DiffusionPipeline.from_pretrained(repo_id, **extra_kwargs, trust_remote_code=True)

        # both custom components must show up in the pipeline config
        assert pipeline.config.unet == ("diffusers_modules.local.my_unet_model", "MyUNetModel")
        assert pipeline.config.scheduler == ("diffusers_modules.local.my_scheduler", "MyScheduler")
        assert pipeline.__class__.__name__ == expected_cls_name

        # run a tiny inference pass and check the output shape
        pipeline = pipeline.to(torch_device)
        images = pipeline("test", num_inference_steps=2, output_type="np")[0]

        assert images.shape == (1, 64, 64, 3)
def test_remote_auto_custom_pipe(self):
    """Load a Hub repo whose custom pipeline is picked up without passing `custom_pipeline`.

    Loading must raise `ValueError` unless `trust_remote_code=True` is passed. With it,
    the resulting pipeline is a `MyPipeline` that uses the custom UNet and scheduler.
    """
    repo_id = "hf-internal-testing/tiny-sdxl-custom-all"

    # loading has to be refused as long as `trust_remote_code` is not passed
    with self.assertRaises(ValueError):
        DiffusionPipeline.from_pretrained(repo_id)

    # custom components "my_unet", "my_scheduler" and the custom pipeline are all
    # resolved automatically once remote code is trusted
    pipeline = DiffusionPipeline.from_pretrained(repo_id, trust_remote_code=True)

    expected_unet = ("diffusers_modules.local.my_unet_model", "MyUNetModel")
    expected_scheduler = ("diffusers_modules.local.my_scheduler", "MyScheduler")
    assert pipeline.config.unet == expected_unet
    assert pipeline.config.scheduler == expected_scheduler
    assert pipeline.__class__.__name__ == "MyPipeline"

    # run a tiny inference pass and check the output shape
    pipeline = pipeline.to(torch_device)
    output = pipeline("test", num_inference_steps=2, output_type="np")
    images = output[0]

    assert images.shape == (1, 64, 64, 3)
def test_local_custom_pipeline_repo(self): def test_local_custom_pipeline_repo(self):
local_custom_pipeline_path = get_tests_dir("fixtures/custom_pipeline") local_custom_pipeline_path = get_tests_dir("fixtures/custom_pipeline")
pipeline = DiffusionPipeline.from_pretrained( pipeline = DiffusionPipeline.from_pretrained(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment