Commit 0d6c17fc authored by Morgan Funtowicz's avatar Morgan Funtowicz
Browse files

black formatting

parent f26a3530
...@@ -705,14 +705,17 @@ class QuestionAnsweringPipeline(Pipeline): ...@@ -705,14 +705,17 @@ class QuestionAnsweringPipeline(Pipeline):
# Convert inputs to features # Convert inputs to features
examples = self._args_parser(*texts, **kwargs) examples = self._args_parser(*texts, **kwargs)
features_list = [ squad_convert_examples_to_features( features_list = [
squad_convert_examples_to_features(
[example], [example],
self.tokenizer, self.tokenizer,
kwargs["max_seq_len"], kwargs["max_seq_len"],
kwargs["doc_stride"], kwargs["doc_stride"],
kwargs["max_question_len"], kwargs["max_question_len"],
False False,
) for example in examples ] )
for example in examples
]
all_answers = [] all_answers = []
for features, example in zip(features_list, examples): for features, example in zip(features_list, examples):
fw_args = self.inputs_for_model([f.__dict__ for f in features]) fw_args = self.inputs_for_model([f.__dict__ for f in features])
...@@ -737,7 +740,10 @@ class QuestionAnsweringPipeline(Pipeline): ...@@ -737,7 +740,10 @@ class QuestionAnsweringPipeline(Pipeline):
end_ = np.exp(end_) / np.sum(np.exp(end_)) end_ = np.exp(end_) / np.sum(np.exp(end_))
# Mask padding and question # Mask padding and question
start_, end_ = start_ * np.abs(np.array(feature.p_mask) - 1), end_ * np.abs(np.array(feature.p_mask) - 1) start_, end_ = (
start_ * np.abs(np.array(feature.p_mask) - 1),
end_ * np.abs(np.array(feature.p_mask) - 1),
)
# TODO : What happens if not possible # TODO : What happens if not possible
# Mask CLS # Mask CLS
...@@ -758,8 +764,8 @@ class QuestionAnsweringPipeline(Pipeline): ...@@ -758,8 +764,8 @@ class QuestionAnsweringPipeline(Pipeline):
} }
for s, e, score in zip(starts, ends, scores) for s, e, score in zip(starts, ends, scores)
] ]
answers = sorted(answers, key = lambda x:x['score'], reverse=True)[:kwargs["topk"]] answers = sorted(answers, key=lambda x: x["score"], reverse=True)[: kwargs["topk"]]
all_answers+=answers all_answers += answers
if len(all_answers) == 1: if len(all_answers) == 1:
return all_answers[0] return all_answers[0]
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment