Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
c2407fdd
Commit
c2407fdd
authored
Dec 09, 2019
by
Morgan Funtowicz
Browse files
Enable the Tensorflow backend.
parent
f116cf59
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
6 additions
and
10 deletions
+6
-10
transformers/pipelines.py
transformers/pipelines.py
+6
-10
No files found.
transformers/pipelines.py
View file @
c2407fdd
...
...
@@ -151,8 +151,7 @@ class QuestionAnsweringPipeline(Pipeline):
texts
=
[(
text
[
'question'
],
text
[
'context'
])
for
text
in
texts
]
inputs
=
self
.
tokenizer
.
batch_encode_plus
(
# texts, add_special_tokens=True, return_tensors='tf' if is_tf_available() else 'pt'
texts
,
add_special_tokens
=
True
,
return_tensors
=
'pt'
texts
,
add_special_tokens
=
True
,
return_tensors
=
'tf'
if
is_tf_available
()
else
'pt'
)
# Remove special_tokens_mask to avoid KeyError
...
...
@@ -161,10 +160,10 @@ class QuestionAnsweringPipeline(Pipeline):
# TODO : Harmonize model arguments across all model
inputs
[
'attention_mask'
]
=
inputs
.
pop
(
'encoder_attention_mask'
)
# if is_tf_available():
if
False
:
if
is_tf_available
():
# TODO trace model
start
,
end
=
self
.
model
(
inputs
)
start
,
end
=
start
.
numpy
(),
end
.
numpy
()
else
:
import
torch
with
torch
.
no_grad
():
...
...
@@ -204,9 +203,7 @@ class QuestionAnsweringPipeline(Pipeline):
# Remove candidate with end < start and end - start > max_answer_len
candidates
=
np
.
tril
(
np
.
triu
(
outer
),
max_answer_len
-
1
)
# start = np.max(candidates, axis=2).argmax(-1)
# end = np.max(candidates, axis=1).argmax(-1)
# Inspired by Chen & al. (https://github.com/facebookresearch/DrQA)
scores_flat
=
candidates
.
flatten
()
if
topk
==
1
:
idx_sort
=
[
np
.
argmax
(
scores_flat
)]
...
...
@@ -257,7 +254,7 @@ SUPPORTED_TASKS = {
},
'question-answering'
:
{
'impl'
:
QuestionAnsweringPipeline
,
#
'tf': TFAutoModelForQuestionAnswering if is_tf_available() else None,
'tf'
:
TFAutoModelForQuestionAnswering
if
is_tf_available
()
else
None
,
'pt'
:
AutoModelForQuestionAnswering
if
is_torch_available
()
else
None
}
}
...
...
@@ -280,8 +277,7 @@ def pipeline(task: str, model, tokenizer: Optional[Union[str, PreTrainedTokenize
raise
KeyError
(
"Unknown task {}, available tasks are {}"
.
format
(
task
,
list
(
SUPPORTED_TASKS
.
keys
())))
targeted_task
=
SUPPORTED_TASKS
[
task
]
# task, allocator = targeted_task['impl'], targeted_task['tf'] if is_tf_available() else targeted_task['pt']
task
,
allocator
=
targeted_task
[
'impl'
],
targeted_task
[
'pt'
]
task
,
allocator
=
targeted_task
[
'impl'
],
targeted_task
[
'tf'
]
if
is_tf_available
()
else
targeted_task
[
'pt'
]
model
=
allocator
.
from_pretrained
(
model
)
return
task
(
model
,
tokenizer
,
**
kwargs
)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment