Unverified Commit aec1ca3a authored by jiqing-feng's avatar jiqing-feng Committed by GitHub
Browse files

[Bug Fix] fix qa pipeline tensor to numpy (#31585)

* fix qa pipeline

* fix tensor to numpy
parent c1e139c2
...@@ -118,7 +118,7 @@ def select_starts_ends( ...@@ -118,7 +118,7 @@ def select_starts_ends(
max_answer_len (`int`): Maximum size of the answer to extract from the model's output. max_answer_len (`int`): Maximum size of the answer to extract from the model's output.
""" """
# Ensure padded tokens & question tokens cannot belong to the set of candidate answers. # Ensure padded tokens & question tokens cannot belong to the set of candidate answers.
undesired_tokens = np.abs(np.array(p_mask) - 1) undesired_tokens = np.abs(p_mask.numpy() - 1)
if attention_mask is not None: if attention_mask is not None:
undesired_tokens = undesired_tokens & attention_mask undesired_tokens = undesired_tokens & attention_mask
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment