"docs/source/ja/tasks/object_detection.md" did not exist on "3df3b9d4bf006ab193b3c1257f3436b9fdb91759"
Unverified Commit 52d516c3 authored by Lucain's avatar Lucain Committed by GitHub
Browse files

Minor fixes in transformers-tools (#23364)

* Few fixes in new Tools implementation

* code quality
parent 728c5e82
......@@ -23,8 +23,8 @@ import os
import tempfile
from typing import Any, Dict, List, Optional, Union
from huggingface_hub import CommitOperationAdd, HfFolder, create_commit, create_repo, hf_hub_download, metadata_update
from huggingface_hub.utils import RepositoryNotFoundError, get_session
from huggingface_hub import create_repo, hf_hub_download, metadata_update, upload_folder
from huggingface_hub.utils import RepositoryNotFoundError, build_hf_headers, get_session
from ..dynamic_module_utils import custom_object_save, get_class_from_dynamic_module, get_imports
from ..image_utils import is_pil_image
......@@ -173,7 +173,14 @@ class Tool:
f.write("\n".join(imports) + "\n")
@classmethod
def from_hub(cls, repo_id, model_repo_id=None, token=None, remote=False, **kwargs):
def from_hub(
cls,
repo_id: str,
model_repo_id: Optional[str] = None,
token: Optional[str] = None,
remote: bool = False,
**kwargs,
):
"""
Loads a tool defined on the Hub.
......@@ -285,22 +292,17 @@ class Tool:
repo_url = create_repo(
repo_id=repo_id, token=token, private=private, exist_ok=True, repo_type="space", space_sdk="gradio"
)
metadata_update(repo_id, {"tags": ["tool"]}, repo_type="space")
repo_id = repo_url.repo_id
metadata_update(repo_id, {"tags": ["tool"]}, repo_type="space")
with tempfile.TemporaryDirectory() as work_dir:
# Save all files.
self.save(work_dir)
os.listdir(work_dir)
operations = [
CommitOperationAdd(path_or_fileobj=os.path.join(work_dir, f), path_in_repo=f)
for f in os.listdir(work_dir)
]
logger.info(f"Uploading the following files to {repo_id}: {','.join(os.listdir(work_dir))}")
return create_commit(
return upload_folder(
repo_id=repo_id,
operations=operations,
commit_message=commit_message,
folder_path=work_dir,
token=token,
create_pr=create_pr,
repo_type="space",
......@@ -482,7 +484,7 @@ class PipelineTool(Tool):
self.hub_kwargs = hub_kwargs
self.hub_kwargs["use_auth_token"] = token
self.is_initialized = False
super().__init__()
def setup(self):
"""
......@@ -508,6 +510,8 @@ class PipelineTool(Tool):
if self.device_map is None:
self.model.to(self.device)
super().setup()
def encode(self, raw_inputs):
"""
Uses the `pre_processor` to prepare the inputs for the `model`.
......@@ -674,9 +678,7 @@ def add_description(description):
## Will move to the Hub
class EndpointClient:
def __init__(self, endpoint_url: str, token: Optional[str] = None):
if token is None:
token = HfFolder().get_token()
self.headers = {"authorization": f"Bearer {token}", "Content-Type": "application/json"}
self.headers = {**build_hf_headers(token=token), "Content-Type": "application/json"}
self.endpoint_url = endpoint_url
@staticmethod
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment