Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
ec6878f6
"...git@developer.sourcefind.cn:chenpangpang/transformers.git" did not exist on "7ccac73f749ce535851b9188f3867d5ed87c318c"
Unverified
Commit
ec6878f6
authored
Nov 03, 2022
by
Nicolas Patry
Committed by
GitHub
Nov 03, 2022
Browse files
Now supporting pathlike in pipelines too. (#20030)
parent
aa39967b
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
12 additions
and
0 deletions
+12
-0
src/transformers/pipelines/__init__.py
src/transformers/pipelines/__init__.py
+3
-0
tests/pipelines/test_pipelines_common.py
tests/pipelines/test_pipelines_common.py
+9
-0
No files found.
src/transformers/pipelines/__init__.py
View file @
ec6878f6
...
@@ -21,6 +21,7 @@ import os
...
@@ -21,6 +21,7 @@ import os
# See the License for the specific language governing permissions and
# See the License for the specific language governing permissions and
# limitations under the License.
# limitations under the License.
import
warnings
import
warnings
from
pathlib
import
Path
from
typing
import
TYPE_CHECKING
,
Any
,
Dict
,
List
,
Optional
,
Tuple
,
Union
from
typing
import
TYPE_CHECKING
,
Any
,
Dict
,
List
,
Optional
,
Tuple
,
Union
from
numpy
import
isin
from
numpy
import
isin
...
@@ -638,6 +639,8 @@ def pipeline(
...
@@ -638,6 +639,8 @@ def pipeline(
" feature_extractor may not be compatible with the default model. Please provide a PreTrainedModel class"
" feature_extractor may not be compatible with the default model. Please provide a PreTrainedModel class"
" or a path/identifier to a pretrained model when providing feature_extractor."
" or a path/identifier to a pretrained model when providing feature_extractor."
)
)
if
isinstance
(
model
,
Path
):
model
=
str
(
model
)
# Config is the primordial information item.
# Config is the primordial information item.
# Instantiate config if needed
# Instantiate config if needed
...
...
tests/pipelines/test_pipelines_common.py
View file @
ec6878f6
...
@@ -356,6 +356,15 @@ class CommonPipelineTest(unittest.TestCase):
...
@@ -356,6 +356,15 @@ class CommonPipelineTest(unittest.TestCase):
self
.
assertEqual
(
pipe
.
_batch_size
,
2
)
self
.
assertEqual
(
pipe
.
_batch_size
,
2
)
self
.
assertEqual
(
pipe
.
_num_workers
,
1
)
self
.
assertEqual
(
pipe
.
_num_workers
,
1
)
@
require_torch
def
test_pipeline_pathlike
(
self
):
pipe
=
pipeline
(
model
=
"hf-internal-testing/tiny-random-distilbert"
)
with
tempfile
.
TemporaryDirectory
()
as
d
:
pipe
.
save_pretrained
(
d
)
path
=
Path
(
d
)
newpipe
=
pipeline
(
task
=
"text-classification"
,
model
=
path
)
self
.
assertIsInstance
(
newpipe
,
TextClassificationPipeline
)
@
require_torch
@
require_torch
def
test_pipeline_override
(
self
):
def
test_pipeline_override
(
self
):
class
MyPipeline
(
TextClassificationPipeline
):
class
MyPipeline
(
TextClassificationPipeline
):
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment