"...git@developer.sourcefind.cn:chenpangpang/transformers.git" did not exist on "d4dbd7ca59bd50dd034e7995cb36e5efed3d9512"
Unverified Commit 52d516c3 authored by Lucain's avatar Lucain Committed by GitHub
Browse files

Minor fixes in transformers-tools (#23364)

* Few fixes in new Tools implementation

* code quality
parent 728c5e82
...@@ -23,8 +23,8 @@ import os ...@@ -23,8 +23,8 @@ import os
import tempfile import tempfile
from typing import Any, Dict, List, Optional, Union from typing import Any, Dict, List, Optional, Union
from huggingface_hub import CommitOperationAdd, HfFolder, create_commit, create_repo, hf_hub_download, metadata_update from huggingface_hub import create_repo, hf_hub_download, metadata_update, upload_folder
from huggingface_hub.utils import RepositoryNotFoundError, get_session from huggingface_hub.utils import RepositoryNotFoundError, build_hf_headers, get_session
from ..dynamic_module_utils import custom_object_save, get_class_from_dynamic_module, get_imports from ..dynamic_module_utils import custom_object_save, get_class_from_dynamic_module, get_imports
from ..image_utils import is_pil_image from ..image_utils import is_pil_image
...@@ -173,7 +173,14 @@ class Tool: ...@@ -173,7 +173,14 @@ class Tool:
f.write("\n".join(imports) + "\n") f.write("\n".join(imports) + "\n")
@classmethod @classmethod
def from_hub(cls, repo_id, model_repo_id=None, token=None, remote=False, **kwargs): def from_hub(
cls,
repo_id: str,
model_repo_id: Optional[str] = None,
token: Optional[str] = None,
remote: bool = False,
**kwargs,
):
""" """
Loads a tool defined on the Hub. Loads a tool defined on the Hub.
...@@ -285,22 +292,17 @@ class Tool: ...@@ -285,22 +292,17 @@ class Tool:
repo_url = create_repo( repo_url = create_repo(
repo_id=repo_id, token=token, private=private, exist_ok=True, repo_type="space", space_sdk="gradio" repo_id=repo_id, token=token, private=private, exist_ok=True, repo_type="space", space_sdk="gradio"
) )
metadata_update(repo_id, {"tags": ["tool"]}, repo_type="space")
repo_id = repo_url.repo_id repo_id = repo_url.repo_id
metadata_update(repo_id, {"tags": ["tool"]}, repo_type="space")
with tempfile.TemporaryDirectory() as work_dir: with tempfile.TemporaryDirectory() as work_dir:
# Save all files. # Save all files.
self.save(work_dir) self.save(work_dir)
os.listdir(work_dir)
operations = [
CommitOperationAdd(path_or_fileobj=os.path.join(work_dir, f), path_in_repo=f)
for f in os.listdir(work_dir)
]
logger.info(f"Uploading the following files to {repo_id}: {','.join(os.listdir(work_dir))}") logger.info(f"Uploading the following files to {repo_id}: {','.join(os.listdir(work_dir))}")
return create_commit( return upload_folder(
repo_id=repo_id, repo_id=repo_id,
operations=operations,
commit_message=commit_message, commit_message=commit_message,
folder_path=work_dir,
token=token, token=token,
create_pr=create_pr, create_pr=create_pr,
repo_type="space", repo_type="space",
...@@ -482,7 +484,7 @@ class PipelineTool(Tool): ...@@ -482,7 +484,7 @@ class PipelineTool(Tool):
self.hub_kwargs = hub_kwargs self.hub_kwargs = hub_kwargs
self.hub_kwargs["use_auth_token"] = token self.hub_kwargs["use_auth_token"] = token
self.is_initialized = False super().__init__()
def setup(self): def setup(self):
""" """
...@@ -508,6 +510,8 @@ class PipelineTool(Tool): ...@@ -508,6 +510,8 @@ class PipelineTool(Tool):
if self.device_map is None: if self.device_map is None:
self.model.to(self.device) self.model.to(self.device)
super().setup()
def encode(self, raw_inputs): def encode(self, raw_inputs):
""" """
Uses the `pre_processor` to prepare the inputs for the `model`. Uses the `pre_processor` to prepare the inputs for the `model`.
...@@ -674,9 +678,7 @@ def add_description(description): ...@@ -674,9 +678,7 @@ def add_description(description):
## Will move to the Hub ## Will move to the Hub
class EndpointClient: class EndpointClient:
def __init__(self, endpoint_url: str, token: Optional[str] = None): def __init__(self, endpoint_url: str, token: Optional[str] = None):
if token is None: self.headers = {**build_hf_headers(token=token), "Content-Type": "application/json"}
token = HfFolder().get_token()
self.headers = {"authorization": f"Bearer {token}", "Content-Type": "application/json"}
self.endpoint_url = endpoint_url self.endpoint_url = endpoint_url
@staticmethod @staticmethod
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment