Commit 6e61e060 authored by Morgan Funtowicz

batch_encode_plus generates the encoder_attention_mask to avoid attending over padded values.

parent 02110485
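
For context on the commit message: the batched tokenizer call already returns a padding mask, so the pipeline can reuse it rather than attend over padded positions. Below is a minimal sketch of that behaviour with the present-day tokenizer API; the checkpoint name is only an example, and at the time of this commit the key was still named `encoder_attention_mask`, hence the rename in the second hunk.

```python
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("bert-base-cased")  # example checkpoint only

questions = ["Who wrote the book?", "When was it published?"]
contexts = [
    "The book was written by Ada.",
    "It was published in 1999, several years after the first draft was finished.",
]

# Encoding question/context pairs in a batch pads every pair to the longest
# one and returns a mask: 1 for real tokens, 0 for padding, so the model
# does not attend over the padded values.
inputs = tokenizer(questions, contexts, padding=True, return_tensors="pt")
print(inputs["attention_mask"])  # zeros mark the padded tail of the shorter pair
```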
@@ -132,7 +132,7 @@ class QuestionAnsweringPipeline(Pipeline):
         # Tabular input
         if 'question' in kwargs and 'context' in kwargs:
-            texts = QuestionAnsweringPipeline.create_sample(kwargs['questions'], kwargs['contexts'])
+            texts = QuestionAnsweringPipeline.create_sample(kwargs['question'], kwargs['context'])
         elif 'data' in kwargs:
             texts = kwargs['data']
         # Generic compatibility with sklearn and Keras
@@ -156,7 +156,10 @@
         )
-        # Remove special_tokens_mask to avoid KeyError
-        _ = inputs.pop('special_tokens_mask')
+        special_tokens_mask, input_len = inputs.pop('special_tokens_mask'), inputs.pop('input_len')
+        # TODO : Harmonize model arguments across all model
+        inputs['attention_mask'] = inputs.pop('encoder_attention_mask')
+        # if is_tf_available():
+        if False:
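
Downstream of this hunk, the renamed `attention_mask` is what the model's forward pass expects, while `special_tokens_mask` and `input_len` are presumably kept for decoding the answer span later. A rough sketch of how the renamed inputs would be consumed, using today's API and an example checkpoint (the surrounding pipeline code is not shown in the hunk, so this is an assumption, not the commit's own code):

```python
import torch
from transformers import AutoModelForQuestionAnswering

model = AutoModelForQuestionAnswering.from_pretrained("bert-base-cased")  # example checkpoint only

with torch.no_grad():
    # `inputs` holds input_ids, token_type_ids and the attention_mask that was
    # recovered from encoder_attention_mask above; padded positions are ignored.
    outputs = model(**inputs)

start_logits, end_logits = outputs.start_logits, outputs.end_logits
```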