Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
52d516c3
Unverified
Commit
52d516c3
authored
May 16, 2023
by
Lucain
Committed by
GitHub
May 16, 2023
Browse files
Minor fixes in transformers-tools (#23364)
* Few fixes in new Tools implementation * code quality
parent
728c5e82
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
17 additions
and
15 deletions
+17
-15
src/transformers/tools/base.py
src/transformers/tools/base.py
+17
-15
No files found.
src/transformers/tools/base.py
View file @
52d516c3
...
...
@@ -23,8 +23,8 @@ import os
import
tempfile
from
typing
import
Any
,
Dict
,
List
,
Optional
,
Union
from
huggingface_hub
import
CommitOperationAdd
,
HfFolder
,
create_commit
,
create_repo
,
hf_hub_download
,
metadata_update
from
huggingface_hub.utils
import
RepositoryNotFoundError
,
get_session
from
huggingface_hub
import
create_repo
,
hf_hub_download
,
metadata_update
,
upload_folder
from
huggingface_hub.utils
import
RepositoryNotFoundError
,
build_hf_headers
,
get_session
from
..dynamic_module_utils
import
custom_object_save
,
get_class_from_dynamic_module
,
get_imports
from
..image_utils
import
is_pil_image
...
...
@@ -173,7 +173,14 @@ class Tool:
f
.
write
(
"
\n
"
.
join
(
imports
)
+
"
\n
"
)
@
classmethod
def
from_hub
(
cls
,
repo_id
,
model_repo_id
=
None
,
token
=
None
,
remote
=
False
,
**
kwargs
):
def
from_hub
(
cls
,
repo_id
:
str
,
model_repo_id
:
Optional
[
str
]
=
None
,
token
:
Optional
[
str
]
=
None
,
remote
:
bool
=
False
,
**
kwargs
,
):
"""
Loads a tool defined on the Hub.
...
...
@@ -285,22 +292,17 @@ class Tool:
repo_url
=
create_repo
(
repo_id
=
repo_id
,
token
=
token
,
private
=
private
,
exist_ok
=
True
,
repo_type
=
"space"
,
space_sdk
=
"gradio"
)
metadata_update
(
repo_id
,
{
"tags"
:
[
"tool"
]},
repo_type
=
"space"
)
repo_id
=
repo_url
.
repo_id
metadata_update
(
repo_id
,
{
"tags"
:
[
"tool"
]},
repo_type
=
"space"
)
with
tempfile
.
TemporaryDirectory
()
as
work_dir
:
# Save all files.
self
.
save
(
work_dir
)
os
.
listdir
(
work_dir
)
operations
=
[
CommitOperationAdd
(
path_or_fileobj
=
os
.
path
.
join
(
work_dir
,
f
),
path_in_repo
=
f
)
for
f
in
os
.
listdir
(
work_dir
)
]
logger
.
info
(
f
"Uploading the following files to
{
repo_id
}
:
{
','
.
join
(
os
.
listdir
(
work_dir
))
}
"
)
return
create_commit
(
return
upload_folder
(
repo_id
=
repo_id
,
operations
=
operations
,
commit_message
=
commit_message
,
folder_path
=
work_dir
,
token
=
token
,
create_pr
=
create_pr
,
repo_type
=
"space"
,
...
...
@@ -482,7 +484,7 @@ class PipelineTool(Tool):
self
.
hub_kwargs
=
hub_kwargs
self
.
hub_kwargs
[
"use_auth_token"
]
=
token
s
elf
.
is_initialized
=
False
s
uper
().
__init__
()
def
setup
(
self
):
"""
...
...
@@ -508,6 +510,8 @@ class PipelineTool(Tool):
if
self
.
device_map
is
None
:
self
.
model
.
to
(
self
.
device
)
super
().
setup
()
def
encode
(
self
,
raw_inputs
):
"""
Uses the `pre_processor` to prepare the inputs for the `model`.
...
...
@@ -674,9 +678,7 @@ def add_description(description):
## Will move to the Hub
class
EndpointClient
:
def
__init__
(
self
,
endpoint_url
:
str
,
token
:
Optional
[
str
]
=
None
):
if
token
is
None
:
token
=
HfFolder
().
get_token
()
self
.
headers
=
{
"authorization"
:
f
"Bearer
{
token
}
"
,
"Content-Type"
:
"application/json"
}
self
.
headers
=
{
**
build_hf_headers
(
token
=
token
),
"Content-Type"
:
"application/json"
}
self
.
endpoint_url
=
endpoint_url
@
staticmethod
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment