Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
ec6878f6
Unverified
Commit
ec6878f6
authored
Nov 03, 2022
by
Nicolas Patry
Committed by
GitHub
Nov 03, 2022
Browse files
Now supporting pathlike in pipelines too. (#20030)
parent
aa39967b
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
12 additions
and
0 deletions
+12
-0
src/transformers/pipelines/__init__.py
src/transformers/pipelines/__init__.py
+3
-0
tests/pipelines/test_pipelines_common.py
tests/pipelines/test_pipelines_common.py
+9
-0
No files found.
src/transformers/pipelines/__init__.py
View file @
ec6878f6
...
@@ -21,6 +21,7 @@ import os
...
@@ -21,6 +21,7 @@ import os
# See the License for the specific language governing permissions and
# See the License for the specific language governing permissions and
# limitations under the License.
# limitations under the License.
import
warnings
import
warnings
from
pathlib
import
Path
from
typing
import
TYPE_CHECKING
,
Any
,
Dict
,
List
,
Optional
,
Tuple
,
Union
from
typing
import
TYPE_CHECKING
,
Any
,
Dict
,
List
,
Optional
,
Tuple
,
Union
from
numpy
import
isin
from
numpy
import
isin
...
@@ -638,6 +639,8 @@ def pipeline(
...
@@ -638,6 +639,8 @@ def pipeline(
" feature_extractor may not be compatible with the default model. Please provide a PreTrainedModel class"
" feature_extractor may not be compatible with the default model. Please provide a PreTrainedModel class"
" or a path/identifier to a pretrained model when providing feature_extractor."
" or a path/identifier to a pretrained model when providing feature_extractor."
)
)
if
isinstance
(
model
,
Path
):
model
=
str
(
model
)
# Config is the primordial information item.
# Config is the primordial information item.
# Instantiate config if needed
# Instantiate config if needed
...
...
tests/pipelines/test_pipelines_common.py
View file @
ec6878f6
...
@@ -356,6 +356,15 @@ class CommonPipelineTest(unittest.TestCase):
...
@@ -356,6 +356,15 @@ class CommonPipelineTest(unittest.TestCase):
self
.
assertEqual
(
pipe
.
_batch_size
,
2
)
self
.
assertEqual
(
pipe
.
_batch_size
,
2
)
self
.
assertEqual
(
pipe
.
_num_workers
,
1
)
self
.
assertEqual
(
pipe
.
_num_workers
,
1
)
@
require_torch
def
test_pipeline_pathlike
(
self
):
pipe
=
pipeline
(
model
=
"hf-internal-testing/tiny-random-distilbert"
)
with
tempfile
.
TemporaryDirectory
()
as
d
:
pipe
.
save_pretrained
(
d
)
path
=
Path
(
d
)
newpipe
=
pipeline
(
task
=
"text-classification"
,
model
=
path
)
self
.
assertIsInstance
(
newpipe
,
TextClassificationPipeline
)
@
require_torch
@
require_torch
def
test_pipeline_override
(
self
):
def
test_pipeline_override
(
self
):
class
MyPipeline
(
TextClassificationPipeline
):
class
MyPipeline
(
TextClassificationPipeline
):
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment