Commit 6e61e060 authored by Morgan Funtowicz

batch_encode_plus generates the encoder_attention_mask to avoid attending over padded values.

parent 02110485
@@ -132,7 +132,7 @@ class QuestionAnsweringPipeline(Pipeline):
         # Tabular input
         if 'question' in kwargs and 'context' in kwargs:
-            texts = QuestionAnsweringPipeline.create_sample(kwargs['questions'], kwargs['contexts'])
+            texts = QuestionAnsweringPipeline.create_sample(kwargs['question'], kwargs['context'])
         elif 'data' in kwargs:
             texts = kwargs['data']
         # Generic compatibility with sklearn and Keras
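
The first hunk fixes a key lookup: the surrounding `if` checks for 'question' and 'context' in kwargs, but the old line read the pluralized 'questions'/'contexts', which would raise a KeyError. A minimal usage sketch of the corrected call convention (the default checkpoint and pipeline factory details are illustrative, not taken from this commit):

    from transformers import pipeline

    # Build a question-answering pipeline; the checkpoint used is whatever the
    # installed transformers version ships as the default for this task.
    nlp = pipeline("question-answering")

    # The fixed branch reads kwargs['question'] and kwargs['context'] (singular),
    # matching the keys checked by the `if` just above it.
    answer = nlp(
        question="What does batch_encode_plus generate?",
        context="batch_encode_plus generates the encoder_attention_mask to "
                "avoid attending over padded values.",
    )
    print(answer)
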
@@ -156,7 +156,10 @@ class QuestionAnsweringPipeline(Pipeline):
            )
            # Remove special_tokens_mask to avoid KeyError
-           _ = inputs.pop('special_tokens_mask')
+           special_tokens_mask, input_len = inputs.pop('special_tokens_mask'), inputs.pop('input_len')
+
+           # TODO : Harmonize model arguments across all model
+           inputs['attention_mask'] = inputs.pop('encoder_attention_mask')
            # if is_tf_available():
            if False:
...
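
The second hunk keeps batch_encode_plus bookkeeping out of the model call: special_tokens_mask and input_len are popped from the encoded inputs, and the mask returned as encoder_attention_mask is renamed to attention_mask, the argument name the model's forward expects, so padded positions are not attended over. Below is a minimal sketch of that pattern outside the pipeline; the tokenizer/model names, the presence of an 'input_len' key, and the exact batch_encode_plus keyword arguments are assumptions that vary across library versions:

    from transformers import AutoTokenizer, AutoModelForQuestionAnswering

    tokenizer = AutoTokenizer.from_pretrained("bert-base-cased")
    model = AutoModelForQuestionAnswering.from_pretrained("bert-base-cased")

    question = "What does the mask avoid?"
    context = "The attention mask avoids attending over padded values."

    # Encode the question/context pair; masks mark real tokens vs. padding.
    inputs = tokenizer.batch_encode_plus(
        [(question, context)],
        return_tensors="pt",
        return_special_tokens_mask=True,
    )

    # Drop keys the model's forward() does not accept (as the pipeline does).
    inputs.pop("special_tokens_mask", None)
    inputs.pop("input_len", None)  # hypothetical key, only present in some versions

    # Expose the padding mask under the argument name the model expects so
    # padded positions are not attended over.
    if "encoder_attention_mask" in inputs:
        inputs["attention_mask"] = inputs.pop("encoder_attention_mask")

    outputs = model(**inputs)
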