Unverified Commit d54622c2 authored by Dhruv Nair's avatar Dhruv Nair Committed by GitHub
Browse files

[Modular] Allow custom blocks to be saved to `local_dir` (#12381)



update
Co-authored-by: default avatarYiYi Xu <yixu310@gmail.com>
parent df8dd778
......@@ -305,6 +305,7 @@ class ModularPipelineBlocks(ConfigMixin, PushToHubMixin):
"cache_dir",
"force_download",
"local_files_only",
"local_dir",
"proxies",
"resume_download",
"revision",
......@@ -331,7 +332,6 @@ class ModularPipelineBlocks(ConfigMixin, PushToHubMixin):
module_file=module_file,
class_name=class_name,
**hub_kwargs,
**kwargs,
)
expected_kwargs, optional_kwargs = block_cls._get_signature_keys(block_cls)
block_kwargs = {
......
......@@ -254,6 +254,7 @@ def get_cached_module_file(
token: Optional[Union[bool, str]] = None,
revision: Optional[str] = None,
local_files_only: bool = False,
local_dir: Optional[str] = None,
):
"""
Prepares Downloads a module from a local folder or a distant repo and returns its path inside the cached
......@@ -332,6 +333,7 @@ def get_cached_module_file(
force_download=force_download,
proxies=proxies,
local_files_only=local_files_only,
local_dir=local_dir,
)
submodule = "git"
module_file = pretrained_model_name_or_path + ".py"
......@@ -355,6 +357,7 @@ def get_cached_module_file(
force_download=force_download,
proxies=proxies,
local_files_only=local_files_only,
local_dir=local_dir,
token=token,
)
submodule = os.path.join("local", "--".join(pretrained_model_name_or_path.split("/")))
......@@ -415,6 +418,7 @@ def get_cached_module_file(
token=token,
revision=revision,
local_files_only=local_files_only,
local_dir=local_dir,
)
return os.path.join(full_submodule, module_file)
......@@ -431,7 +435,7 @@ def get_class_from_dynamic_module(
token: Optional[Union[bool, str]] = None,
revision: Optional[str] = None,
local_files_only: bool = False,
**kwargs,
local_dir: Optional[str] = None,
):
"""
Extracts a class from a module file, present in the local folder or repository of a model.
......@@ -496,5 +500,6 @@ def get_class_from_dynamic_module(
token=token,
revision=revision,
local_files_only=local_files_only,
local_dir=local_dir,
)
return get_class_in_module(class_name, final_module)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment