Unverified Commit b9e99654 authored by Dhruv Nair's avatar Dhruv Nair Committed by GitHub
Browse files

[Modular] Updates for Custom Pipeline Blocks (#11940)

* update

* update

* update
parent 478df933
...@@ -323,6 +323,7 @@ class ModularPipelineBlocks(ConfigMixin, PushToHubMixin): ...@@ -323,6 +323,7 @@ class ModularPipelineBlocks(ConfigMixin, PushToHubMixin):
""" """
config_name = "config.json" config_name = "config.json"
model_name = None
@classmethod @classmethod
def _get_signature_keys(cls, obj): def _get_signature_keys(cls, obj):
...@@ -333,6 +334,14 @@ class ModularPipelineBlocks(ConfigMixin, PushToHubMixin): ...@@ -333,6 +334,14 @@ class ModularPipelineBlocks(ConfigMixin, PushToHubMixin):
return expected_modules, optional_parameters return expected_modules, optional_parameters
@property
def expected_components(self) -> List[ComponentSpec]:
return []
@property
def expected_configs(self) -> List[ConfigSpec]:
return []
@classmethod @classmethod
def from_pretrained( def from_pretrained(
cls, cls,
...@@ -358,7 +367,9 @@ class ModularPipelineBlocks(ConfigMixin, PushToHubMixin): ...@@ -358,7 +367,9 @@ class ModularPipelineBlocks(ConfigMixin, PushToHubMixin):
trust_remote_code, pretrained_model_name_or_path, has_remote_code trust_remote_code, pretrained_model_name_or_path, has_remote_code
) )
if not (has_remote_code and trust_remote_code): if not (has_remote_code and trust_remote_code):
raise ValueError("TODO") raise ValueError(
"Selected model repository does not happear to have any custom code or does not have a valid `config.json` file."
)
class_ref = config["auto_map"][cls.__name__] class_ref = config["auto_map"][cls.__name__]
module_file, class_name = class_ref.split(".") module_file, class_name = class_ref.split(".")
...@@ -367,7 +378,6 @@ class ModularPipelineBlocks(ConfigMixin, PushToHubMixin): ...@@ -367,7 +378,6 @@ class ModularPipelineBlocks(ConfigMixin, PushToHubMixin):
pretrained_model_name_or_path, pretrained_model_name_or_path,
module_file=module_file, module_file=module_file,
class_name=class_name, class_name=class_name,
is_modular=True,
**hub_kwargs, **hub_kwargs,
**kwargs, **kwargs,
) )
......
...@@ -93,7 +93,7 @@ class ComponentSpec: ...@@ -93,7 +93,7 @@ class ComponentSpec:
config: Optional[FrozenDict] = None config: Optional[FrozenDict] = None
# YiYi Notes: should we change it to pretrained_model_name_or_path for consistency? a bit long for a field name # YiYi Notes: should we change it to pretrained_model_name_or_path for consistency? a bit long for a field name
repo: Optional[Union[str, List[str]]] = field(default=None, metadata={"loading": True}) repo: Optional[Union[str, List[str]]] = field(default=None, metadata={"loading": True})
subfolder: Optional[str] = field(default=None, metadata={"loading": True}) subfolder: Optional[str] = field(default="", metadata={"loading": True})
variant: Optional[str] = field(default=None, metadata={"loading": True}) variant: Optional[str] = field(default=None, metadata={"loading": True})
revision: Optional[str] = field(default=None, metadata={"loading": True}) revision: Optional[str] = field(default=None, metadata={"loading": True})
default_creation_method: Literal["from_config", "from_pretrained"] = "from_pretrained" default_creation_method: Literal["from_config", "from_pretrained"] = "from_pretrained"
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment