Commit a7d3794a authored by Morgan Funtowicz's avatar Morgan Funtowicz
Browse files

Remove token_type_ids for compatibility with DistilBert

parent fe0f552e
......@@ -20,7 +20,7 @@ from typing import Union, Optional, Tuple, List, Dict
import numpy as np
from transformers import is_tf_available, logger, AutoTokenizer, PreTrainedTokenizer, is_torch_available
from transformers import is_tf_available, is_torch_available, logger, AutoTokenizer, PreTrainedTokenizer
if is_tf_available():
from transformers import TFAutoModelForSequenceClassification, TFAutoModelForQuestionAnswering
......@@ -154,6 +154,8 @@ class QuestionAnsweringPipeline(Pipeline):
return_attention_masks=True, return_input_lengths=False
)
token_type_ids = inputs.pop('token_type_ids')
if is_tf_available():
# TODO trace model
start, end = self.model(inputs)
......@@ -167,7 +169,7 @@ class QuestionAnsweringPipeline(Pipeline):
answers = []
for i in range(len(texts)):
context_idx = inputs['token_type_ids'][i] == 1
context_idx = token_type_ids[i] == 1
start_, end_ = start[i, context_idx], end[i, context_idx]
# Normalize logits and spans to retrieve the answer
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment