Commit 348e19aa authored by Morgan Funtowicz

Expose attention_masks and input_lengths arguments to batch_encode_plus

parent c2407fdd
@@ -149,14 +149,11 @@ class QuestionAnsweringPipeline(Pipeline):
         # Map to tuple (question, context)
         texts = [(text['question'], text['context']) for text in texts]
         inputs = self.tokenizer.batch_encode_plus(
-            texts, add_special_tokens=True, return_tensors='tf' if is_tf_available() else 'pt'
+            texts, add_special_tokens=False, return_tensors='tf' if is_tf_available() else 'pt',
+            return_attention_masks=True, return_input_lengths=False
         )
-        # Remove special_tokens_mask to avoid KeyError
-        special_tokens_mask, input_len = inputs.pop('special_tokens_mask'), inputs.pop('input_len')
 
         # TODO : Harmonize model arguments across all model
         inputs['attention_mask'] = inputs.pop('encoder_attention_mask')
...
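A minimal sketch of the resulting call, assuming the tokenizer API as it stood at this commit (where `batch_encode_plus` accepts `return_attention_masks` / `return_input_lengths` and returns the mask under `encoder_attention_mask`); the checkpoint name is only an illustrative example, and newer library versions use different argument names:

```python
# Sketch (not part of this commit): how the pipeline's call behaves after this change.
from transformers import AutoTokenizer, is_tf_available

# Example checkpoint; any compatible tokenizer would do.
tokenizer = AutoTokenizer.from_pretrained('bert-base-cased')

# (question, context) pairs, as built by QuestionAnsweringPipeline.
texts = [('Who wrote the book?', 'The book was written by Jane Doe.')]

inputs = tokenizer.batch_encode_plus(
    texts,
    add_special_tokens=False,
    return_tensors='tf' if is_tf_available() else 'pt',
    return_attention_masks=True,   # mask comes back directly in the dict
    return_input_lengths=False,    # so there is no 'input_len' entry to pop
)

# With the flags above there is nothing left to strip out of `inputs`;
# the pipeline only renames the mask to the argument name the models expect.
inputs['attention_mask'] = inputs.pop('encoder_attention_mask')
```

The point of the change is that the tokenizer itself now controls which auxiliary keys are returned, so the pipeline no longer has to pop `special_tokens_mask` and `input_len` defensively after the fact.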