preprocess_wsc.py 1.21 KB
Newer Older
lintangsutawika's avatar
lintangsutawika committed
1
import re
lintangsutawika's avatar
lintangsutawika committed
2
from lm_eval.utils import general_detokenize
lintangsutawika's avatar
lintangsutawika committed
3
4


lintangsutawika's avatar
lintangsutawika committed
5
def t5_prompt_doc_to_text(x):
    """Build the T5-style WSC prompt by marking both spans inline.

    The noun phrase (``span1``) is wrapped in ``*`` markers and the pronoun
    (``span2``) in ``#`` markers, e.g.
    ``"* The city councilmen * refused ... because # they # feared violence."``

    Args:
        x: A WSC example with ``text``, ``span1_text``, ``span1_index``,
            ``span2_text`` and ``span2_index`` keys. Span indices are word
            offsets into ``text``.

    Returns:
        The passage with both spans marked. A span whose text is not found
        at its word offset is left unmarked.
    """

    def _mark_span(text, span_str, span_idx, mark):
        # Skip exactly ``span_idx`` whitespace-terminated words, then match the
        # span literally. ``re.escape`` keeps punctuation such as "(" or "?"
        # from being read as regex syntax. The pattern is built with an
        # f-string rather than ``re.sub`` so backslashes in the span are not
        # treated as replacement escapes. (``{{``/``}}`` are literal braces.)
        pattern = rf"^((?:\S+\s){{{span_idx}}})({re.escape(span_str)})"
        return re.sub(pattern, r"\1{0} \2 {0}".format(mark), text)

    text = x["text"]
    marked = _mark_span(text, x["span1_text"], x["span1_index"], "*")

    span2_index = x["span2_index"]
    # Marking span1 inserts two extra "words" (the opening and closing "*"),
    # shifting every later word offset by 2. Compensate only when the marks
    # were actually inserted; otherwise span2 would be looked up two words
    # too far.
    if marked != text and x["span1_index"] < span2_index:
        span2_index += 2

    return _mark_span(marked, x["span2_text"], span2_index, "#")
lintangsutawika's avatar
lintangsutawika committed
19
20


lintangsutawika's avatar
lintangsutawika committed
21
22
def default_doc_to_text(x):
    """Build the default WSC prompt asking whether the pronoun refers to the noun.

    The pronoun (``span2``) is highlighted in the passage as ``*pronoun*`` and
    the passage is passed through ``general_detokenize``.

    Args:
        x: A WSC example with ``text``, ``span1_text``, ``span2_text`` and
            ``span2_index`` keys.

    Returns:
        A three-line prompt: "Passage: ...", "Question: ...", "Answer:".
    """
    raw_passage = x["text"]
    noun = x["span1_text"]
    pronoun = x["span2_text"]
    span2_index = x["span2_index"]

    # NOTE: HuggingFace span indices are word-based not character-based.
    # Find the real character offset of the span2 word in the raw text.
    # Deriving it from the length of a re-joined prefix breaks when the
    # passage has runs of whitespace, or when the pronoun is the first word.
    words = list(re.finditer(r"\S+", raw_passage))
    if span2_index < len(words):
        start = words[span2_index].start()
    else:
        # Out-of-range index: append the highlighted pronoun at the end.
        start = len(raw_passage)

    pre = " ".join(match.group() for match in words[:span2_index])
    post = raw_passage[start + len(pronoun) :]
    separator = " " if pre else ""
    passage = general_detokenize(pre + separator + "*{}*".format(pronoun) + post)

    text = (
        f"Passage: {passage}\n"
        + f'Question: In the passage above, does the pronoun "*{pronoun}*" refer to "*{noun}*"?\n'
        + "Answer:"
    )
    return text