Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
a7d3794a
Commit
a7d3794a
authored
Dec 09, 2019
by
Morgan Funtowicz
Browse files
Remove token_type_ids for compatibility with DistilBert
parent
fe0f552e
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
4 additions
and
2 deletions
+4
-2
transformers/pipelines.py
transformers/pipelines.py
+4
-2
No files found.
transformers/pipelines.py
View file @
a7d3794a
...
@@ -20,7 +20,7 @@ from typing import Union, Optional, Tuple, List, Dict
...
@@ -20,7 +20,7 @@ from typing import Union, Optional, Tuple, List, Dict
import
numpy
as
np
import
numpy
as
np
from
transformers
import
is_tf_available
,
logger
,
AutoTokenizer
,
PreTrainedTokenizer
,
is_torch_available
from
transformers
import
is_tf_available
,
is_torch_available
,
logger
,
AutoTokenizer
,
PreTrainedTokenizer
if
is_tf_available
():
if
is_tf_available
():
from
transformers
import
TFAutoModelForSequenceClassification
,
TFAutoModelForQuestionAnswering
from
transformers
import
TFAutoModelForSequenceClassification
,
TFAutoModelForQuestionAnswering
...
@@ -154,6 +154,8 @@ class QuestionAnsweringPipeline(Pipeline):
...
@@ -154,6 +154,8 @@ class QuestionAnsweringPipeline(Pipeline):
return_attention_masks
=
True
,
return_input_lengths
=
False
return_attention_masks
=
True
,
return_input_lengths
=
False
)
)
token_type_ids
=
inputs
.
pop
(
'token_type_ids'
)
if
is_tf_available
():
if
is_tf_available
():
# TODO trace model
# TODO trace model
start
,
end
=
self
.
model
(
inputs
)
start
,
end
=
self
.
model
(
inputs
)
...
@@ -167,7 +169,7 @@ class QuestionAnsweringPipeline(Pipeline):
...
@@ -167,7 +169,7 @@ class QuestionAnsweringPipeline(Pipeline):
answers
=
[]
answers
=
[]
for
i
in
range
(
len
(
texts
)):
for
i
in
range
(
len
(
texts
)):
context_idx
=
inputs
[
'
token_type_ids
'
]
[
i
]
==
1
context_idx
=
token_type_ids
[
i
]
==
1
start_
,
end_
=
start
[
i
,
context_idx
],
end
[
i
,
context_idx
]
start_
,
end_
=
start
[
i
,
context_idx
],
end
[
i
,
context_idx
]
# Normalize logits and spans to retrieve the answer
# Normalize logits and spans to retrieve the answer
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment