Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
52d516c3
Unverified
Commit
52d516c3
authored
May 16, 2023
by
Lucain
Committed by
GitHub
May 16, 2023
Browse files
Minor fixes in transformers-tools (#23364)
* Few fixes in new Tools implementation * code quality
parent
728c5e82
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
17 additions
and
15 deletions
+17
-15
src/transformers/tools/base.py
src/transformers/tools/base.py
+17
-15
No files found.
src/transformers/tools/base.py
View file @
52d516c3
...
@@ -23,8 +23,8 @@ import os
...
@@ -23,8 +23,8 @@ import os
import
tempfile
import
tempfile
from
typing
import
Any
,
Dict
,
List
,
Optional
,
Union
from
typing
import
Any
,
Dict
,
List
,
Optional
,
Union
from
huggingface_hub
import
CommitOperationAdd
,
HfFolder
,
create_commit
,
create_repo
,
hf_hub_download
,
metadata_update
from
huggingface_hub
import
create_repo
,
hf_hub_download
,
metadata_update
,
upload_folder
from
huggingface_hub.utils
import
RepositoryNotFoundError
,
get_session
from
huggingface_hub.utils
import
RepositoryNotFoundError
,
build_hf_headers
,
get_session
from
..dynamic_module_utils
import
custom_object_save
,
get_class_from_dynamic_module
,
get_imports
from
..dynamic_module_utils
import
custom_object_save
,
get_class_from_dynamic_module
,
get_imports
from
..image_utils
import
is_pil_image
from
..image_utils
import
is_pil_image
...
@@ -173,7 +173,14 @@ class Tool:
...
@@ -173,7 +173,14 @@ class Tool:
f
.
write
(
"
\n
"
.
join
(
imports
)
+
"
\n
"
)
f
.
write
(
"
\n
"
.
join
(
imports
)
+
"
\n
"
)
@
classmethod
@
classmethod
def
from_hub
(
cls
,
repo_id
,
model_repo_id
=
None
,
token
=
None
,
remote
=
False
,
**
kwargs
):
def
from_hub
(
cls
,
repo_id
:
str
,
model_repo_id
:
Optional
[
str
]
=
None
,
token
:
Optional
[
str
]
=
None
,
remote
:
bool
=
False
,
**
kwargs
,
):
"""
"""
Loads a tool defined on the Hub.
Loads a tool defined on the Hub.
...
@@ -285,22 +292,17 @@ class Tool:
...
@@ -285,22 +292,17 @@ class Tool:
repo_url
=
create_repo
(
repo_url
=
create_repo
(
repo_id
=
repo_id
,
token
=
token
,
private
=
private
,
exist_ok
=
True
,
repo_type
=
"space"
,
space_sdk
=
"gradio"
repo_id
=
repo_id
,
token
=
token
,
private
=
private
,
exist_ok
=
True
,
repo_type
=
"space"
,
space_sdk
=
"gradio"
)
)
metadata_update
(
repo_id
,
{
"tags"
:
[
"tool"
]},
repo_type
=
"space"
)
repo_id
=
repo_url
.
repo_id
repo_id
=
repo_url
.
repo_id
metadata_update
(
repo_id
,
{
"tags"
:
[
"tool"
]},
repo_type
=
"space"
)
with
tempfile
.
TemporaryDirectory
()
as
work_dir
:
with
tempfile
.
TemporaryDirectory
()
as
work_dir
:
# Save all files.
# Save all files.
self
.
save
(
work_dir
)
self
.
save
(
work_dir
)
os
.
listdir
(
work_dir
)
operations
=
[
CommitOperationAdd
(
path_or_fileobj
=
os
.
path
.
join
(
work_dir
,
f
),
path_in_repo
=
f
)
for
f
in
os
.
listdir
(
work_dir
)
]
logger
.
info
(
f
"Uploading the following files to
{
repo_id
}
:
{
','
.
join
(
os
.
listdir
(
work_dir
))
}
"
)
logger
.
info
(
f
"Uploading the following files to
{
repo_id
}
:
{
','
.
join
(
os
.
listdir
(
work_dir
))
}
"
)
return
create_commit
(
return
upload_folder
(
repo_id
=
repo_id
,
repo_id
=
repo_id
,
operations
=
operations
,
commit_message
=
commit_message
,
commit_message
=
commit_message
,
folder_path
=
work_dir
,
token
=
token
,
token
=
token
,
create_pr
=
create_pr
,
create_pr
=
create_pr
,
repo_type
=
"space"
,
repo_type
=
"space"
,
...
@@ -482,7 +484,7 @@ class PipelineTool(Tool):
...
@@ -482,7 +484,7 @@ class PipelineTool(Tool):
self
.
hub_kwargs
=
hub_kwargs
self
.
hub_kwargs
=
hub_kwargs
self
.
hub_kwargs
[
"use_auth_token"
]
=
token
self
.
hub_kwargs
[
"use_auth_token"
]
=
token
s
elf
.
is_initialized
=
False
s
uper
().
__init__
()
def
setup
(
self
):
def
setup
(
self
):
"""
"""
...
@@ -508,6 +510,8 @@ class PipelineTool(Tool):
...
@@ -508,6 +510,8 @@ class PipelineTool(Tool):
if
self
.
device_map
is
None
:
if
self
.
device_map
is
None
:
self
.
model
.
to
(
self
.
device
)
self
.
model
.
to
(
self
.
device
)
super
().
setup
()
def
encode
(
self
,
raw_inputs
):
def
encode
(
self
,
raw_inputs
):
"""
"""
Uses the `pre_processor` to prepare the inputs for the `model`.
Uses the `pre_processor` to prepare the inputs for the `model`.
...
@@ -674,9 +678,7 @@ def add_description(description):
...
@@ -674,9 +678,7 @@ def add_description(description):
## Will move to the Hub
## Will move to the Hub
class
EndpointClient
:
class
EndpointClient
:
def
__init__
(
self
,
endpoint_url
:
str
,
token
:
Optional
[
str
]
=
None
):
def
__init__
(
self
,
endpoint_url
:
str
,
token
:
Optional
[
str
]
=
None
):
if
token
is
None
:
self
.
headers
=
{
**
build_hf_headers
(
token
=
token
),
"Content-Type"
:
"application/json"
}
token
=
HfFolder
().
get_token
()
self
.
headers
=
{
"authorization"
:
f
"Bearer
{
token
}
"
,
"Content-Type"
:
"application/json"
}
self
.
endpoint_url
=
endpoint_url
self
.
endpoint_url
=
endpoint_url
@
staticmethod
@
staticmethod
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment