Commit e347725d (chenpangpang/transformers)
Authored Dec 17, 2019 by Morgan Funtowicz
More fine-grained control over pipeline creation with config argument.
Parent: 55397dfb
Showing 1 changed file with 17 additions and 8 deletions.
transformers/pipelines.py
...
...
@@ -497,7 +497,7 @@ class QuestionAnsweringPipeline(Pipeline):
                     'score': score.item(),
                     'start': np.where(char_to_word == feature.token_to_orig_map[s])[0][0].item(),
                     'end': np.where(char_to_word == feature.token_to_orig_map[e])[0][-1].item(),
-                    'answer': ' '.join(example.doc_tokens[feature.token_to_orig_map[s]:feature.token_to_orig_map[e] + 1])
+                    'answer': ' '.join(example.doc_tokens[feature.token_to_orig_map[s]:feature.token_to_orig_map[e] + 1])
                 }
                 for s, e, score in zip(starts, ends, scores)
             ]
...
...
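The 'start' and 'end' fields in the hunk above recover character offsets from word indices with np.where. A minimal standalone sketch of that pattern, using a toy char_to_word array that is not taken from the repository:

```python
# Toy illustration of the np.where pattern used in QuestionAnsweringPipeline:
# char_to_word maps each character position to the index of the word it belongs to.
import numpy as np

char_to_word = np.array([0, 0, 0, 0, 0, 1, 1, 1, 2, 2, 2, 2])  # hypothetical mapping
word = 1
start_char = np.where(char_to_word == word)[0][0].item()   # first character of word 1 -> 5
end_char = np.where(char_to_word == word)[0][-1].item()    # last character of word 1 -> 7
print(start_char, end_char)
```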
@@ -612,7 +612,8 @@ SUPPORTED_TASKS = {
 }
 
-def pipeline(task: str, model, config: Optional[PretrainedConfig] = None, tokenizer: Optional[Union[str, PreTrainedTokenizer]] = None, **kwargs) -> Pipeline:
+def pipeline(task: str, model, config: Optional[Union[str, PretrainedConfig]] = None,
+             tokenizer: Optional[Union[str, PreTrainedTokenizer]] = None, **kwargs) -> Pipeline:
     """
     Utility factory method to build a pipeline.
 
     Pipeline are made of:
...
...
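The signature change above widens `config` from `Optional[PretrainedConfig]` to `Optional[Union[str, PretrainedConfig]]`, so a configuration can now be referenced by name or path as well as passed as an object. A hedged usage sketch; the task name and checkpoint name are assumptions, not taken from this commit:

```python
from transformers import pipeline

# Hypothetical call illustrating the widened `config` parameter: a string is now
# accepted and resolved internally via PretrainedConfig.from_pretrained(config).
nlp = pipeline('question-answering',
               model='distilbert-base-uncased-distilled-squad',    # assumed checkpoint name
               config='distilbert-base-uncased-distilled-squad')   # string config identifier
```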
@@ -637,13 +638,21 @@ def pipeline(task: str, model, config: Optional[PretrainedConfig] = None, tokeni
     task, allocator = targeted_task['impl'], targeted_task['tf'] if is_tf_available() else targeted_task['pt']
 
     # Special handling for model conversion
-    from_tf = model.endswith('.h5') and not is_tf_available()
-    from_pt = model.endswith('.bin') and not is_torch_available()
+    if isinstance(model, str):
+        from_tf = model.endswith('.h5') and not is_tf_available()
+        from_pt = model.endswith('.bin') and not is_torch_available()
 
-    if from_tf:
-        logger.warning('Model might be a TensorFlow model (ending with `.h5`) but TensorFlow is not available. Trying to load the model with PyTorch.')
-    elif from_pt:
-        logger.warning('Model might be a PyTorch model (ending with `.bin`) but PyTorch is not available. Trying to load the model with Tensorflow.')
+        if from_tf:
+            logger.warning('Model might be a TensorFlow model (ending with `.h5`) but TensorFlow is not available. '
+                           'Trying to load the model with PyTorch.')
+        elif from_pt:
+            logger.warning('Model might be a PyTorch model (ending with `.bin`) but PyTorch is not available. '
+                           'Trying to load the model with Tensorflow.')
+    else:
+        from_tf = from_pt = False
+
+    if isinstance(config, str):
+        config = PretrainedConfig.from_pretrained(config)
 
     if allocator.__name__.startswith('TF'):
         model = allocator.from_pretrained(model, config=config, from_pt=from_pt)
...
...
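The hunk above only runs the `.h5`/`.bin` extension checks when `model` is a string, and resolves a string `config` through `PretrainedConfig.from_pretrained`. A minimal sketch of the detection logic in isolation, assuming `is_tf_available` and `is_torch_available` can be imported from `transformers` (the diff itself only shows them being called):

```python
from transformers import is_tf_available, is_torch_available

def detect_conversion(model):
    # Hypothetical helper mirroring the logic added in this commit (not part of the library).
    # A '.h5' checkpoint is TensorFlow-format; if TF is unavailable it must be converted
    # when loaded with PyTorch, and symmetrically for '.bin' when PyTorch is unavailable.
    if isinstance(model, str):
        from_tf = model.endswith('.h5') and not is_tf_available()
        from_pt = model.endswith('.bin') and not is_torch_available()
    else:
        # An already-instantiated model object needs no conversion.
        from_tf = from_pt = False
    return from_tf, from_pt
```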