Commit c2407fdd authored by Morgan Funtowicz

Enable the Tensorflow backend.

parent f116cf59
@@ -151,8 +151,7 @@ class QuestionAnsweringPipeline(Pipeline):
         texts = [(text['question'], text['context']) for text in texts]
         inputs = self.tokenizer.batch_encode_plus(
-            # texts, add_special_tokens=True, return_tensors='tf' if is_tf_available() else 'pt'
-            texts, add_special_tokens=True, return_tensors='pt'
+            texts, add_special_tokens=True, return_tensors='tf' if is_tf_available() else 'pt'
         )
         # Remove special_tokens_mask to avoid KeyError
@@ -161,10 +160,10 @@ class QuestionAnsweringPipeline(Pipeline):
         # TODO : Harmonize model arguments across all model
         inputs['attention_mask'] = inputs.pop('encoder_attention_mask')
-        # if is_tf_available():
-        if False:
+        if is_tf_available():
             # TODO trace model
             start, end = self.model(inputs)
             start, end = start.numpy(), end.numpy()
         else:
             import torch
             with torch.no_grad():
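The two hunks above follow the same pattern: pick the tensor format at encoding time with is_tf_available(), then dispatch the forward pass either to TensorFlow (eager execution, no gradient context needed) or to PyTorch inside torch.no_grad(). A minimal standalone sketch of that pattern, outside the pipeline class; the checkpoint name and the question/context pair are illustrative and not part of this commit:

# Sketch of the backend dispatch used in the two hunks above (illustrative inputs).
from transformers import AutoTokenizer, is_tf_available

checkpoint = 'distilbert-base-uncased-distilled-squad'   # illustrative QA checkpoint
tokenizer = AutoTokenizer.from_pretrained(checkpoint)

inputs = tokenizer.batch_encode_plus(
    [('Who maintains the library?', 'The library is maintained by Hugging Face.')],
    add_special_tokens=True,
    return_tensors='tf' if is_tf_available() else 'pt'
)

if is_tf_available():
    from transformers import TFAutoModelForQuestionAnswering
    model = TFAutoModelForQuestionAnswering.from_pretrained(checkpoint)
    outputs = model(inputs)                      # TF runs eagerly, no gradient context needed
    start, end = outputs[0].numpy(), outputs[1].numpy()
else:
    import torch
    from transformers import AutoModelForQuestionAnswering
    model = AutoModelForQuestionAnswering.from_pretrained(checkpoint)
    with torch.no_grad():                        # inference only, disable autograd
        outputs = model(**inputs)
        start, end = outputs[0].numpy(), outputs[1].numpy()

print(start.shape, end.shape)                    # start/end logits: (batch_size, sequence_length)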
@@ -204,9 +203,7 @@ class QuestionAnsweringPipeline(Pipeline):
         # Remove candidate with end < start and end - start > max_answer_len
         candidates = np.tril(np.triu(outer), max_answer_len - 1)
-        # start = np.max(candidates, axis=2).argmax(-1)
-        # end = np.max(candidates, axis=1).argmax(-1)
         # Inspired by Chen & al. (https://github.com/facebookresearch/DrQA)
         scores_flat = candidates.flatten()
         if topk == 1:
             idx_sort = [np.argmax(scores_flat)]
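The np.triu/np.tril line kept above is the DrQA-style span selection: score every (start, end) token pair, zero out pairs where end < start or where the span exceeds max_answer_len, then take the best remaining pair. A small NumPy-only illustration on a single example; the helper name and the toy probabilities are made up, and the real pipeline works on batched logits, which is why the removed commented-out lines referenced an extra axis:

# Single-example sketch of the masking trick above (hypothetical helper, toy inputs).
import numpy as np

def decode_span(start_probs, end_probs, max_answer_len=15):
    # outer[i, j] = P(start=i) * P(end=j) for every candidate span (i, j)
    outer = np.matmul(np.expand_dims(start_probs, -1), np.expand_dims(end_probs, 0))
    # triu: drop spans with end < start; tril(..., k): drop spans longer than max_answer_len
    candidates = np.tril(np.triu(outer), max_answer_len - 1)
    # highest-scoring surviving (start, end) pair
    start_idx, end_idx = np.unravel_index(np.argmax(candidates), candidates.shape)
    return start_idx, end_idx, candidates[start_idx, end_idx]

start_probs = np.array([0.1, 0.6, 0.2, 0.1])
end_probs = np.array([0.1, 0.1, 0.7, 0.1])
print(decode_span(start_probs, end_probs))   # best span is tokens 1..2 with score ~0.42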
@@ -257,7 +254,7 @@ SUPPORTED_TASKS = {
     },
     'question-answering': {
         'impl': QuestionAnsweringPipeline,
-        # 'tf': TFAutoModelForQuestionAnswering if is_tf_available() else None,
+        'tf': TFAutoModelForQuestionAnswering if is_tf_available() else None,
         'pt': AutoModelForQuestionAnswering if is_torch_available() else None
     }
 }
@@ -280,8 +277,7 @@ def pipeline(task: str, model, tokenizer: Optional[Union[str, PreTrainedTokenize
         raise KeyError("Unknown task {}, available tasks are {}".format(task, list(SUPPORTED_TASKS.keys())))
     targeted_task = SUPPORTED_TASKS[task]
-    # task, allocator = targeted_task['impl'], targeted_task['tf'] if is_tf_available() else targeted_task['pt']
-    task, allocator = targeted_task['impl'], targeted_task['pt']
+    task, allocator = targeted_task['impl'], targeted_task['tf'] if is_tf_available() else targeted_task['pt']
     model = allocator.from_pretrained(model)
     return task(model, tokenizer, **kwargs)
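With the two hunks above, pipeline() resolves the model allocator from the 'tf' entry whenever TensorFlow is installed and falls back to the 'pt' entry otherwise. A condensed, standalone version of that lookup; the checkpoint name and the trimmed-down registry are illustrative, the real SUPPORTED_TASKS dict is the one defined in this module:

# Sketch of the allocator lookup performed by pipeline() (illustrative checkpoint).
from transformers import AutoTokenizer, is_tf_available, is_torch_available

# Guarded imports, mirroring how the module only exposes backend classes that are installed.
TFAutoModelForQuestionAnswering = AutoModelForQuestionAnswering = None
if is_tf_available():
    from transformers import TFAutoModelForQuestionAnswering
if is_torch_available():
    from transformers import AutoModelForQuestionAnswering

SUPPORTED_TASKS = {
    'question-answering': {
        'tf': TFAutoModelForQuestionAnswering if is_tf_available() else None,
        'pt': AutoModelForQuestionAnswering if is_torch_available() else None,
    }
}

checkpoint = 'distilbert-base-uncased-distilled-squad'
targeted_task = SUPPORTED_TASKS['question-answering']
allocator = targeted_task['tf'] if is_tf_available() else targeted_task['pt']
if allocator is None:
    raise RuntimeError('Neither TensorFlow nor PyTorch is available')

model = allocator.from_pretrained(checkpoint)       # TF or PyTorch class, picked at runtime
tokenizer = AutoTokenizer.from_pretrained(checkpoint)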