preprocess_wsc.py 1.02 KB
Newer Older
lintangsutawika's avatar
lintangsutawika committed
1
import re
lintangsutawika's avatar
lintangsutawika committed
2
from lm_eval.utils import general_detokenize
lintangsutawika's avatar
lintangsutawika committed
3
4


lintangsutawika's avatar
lintangsutawika committed
5
def t5_prompt_doc_to_text(x):
    """Build the T5-style ``"wsc: ..."`` prompt for a WSC example.

    The pronoun (``span2_text``) is wrapped in ``*`` marks at its word
    position (``span2_index``) inside ``text``. The noun span is not marked.

    Args:
        x: A WSC example providing ``"text"``, ``"span2_text"`` and
            ``"span2_index"`` (a word index into ``text``).

    Returns:
        The passage with the pronoun marked, prefixed by ``"wsc: "``. If the
        span cannot be found at the given word index, the passage is returned
        unmarked.
    """

    def _mark_span(text, span_str, span_idx, mark):
        # Skip ``span_idx`` whitespace-terminated words from the start of the
        # text, then require ``span_str`` literally.
        # - ``re.escape`` makes regex metacharacters in the span (".", "?",
        #   "(" ...) match literally instead of altering or breaking the
        #   pattern.
        # - Building the pattern by string formatting (rather than
        #   ``re.sub``) keeps backslashes in the span from being treated as
        #   replacement escapes.
        # - ``\s+`` lets runs of whitespace separate words, consistent with a
        #   word-based index.
        pattern = r"^((?:\S+\s+){%s})(%s)" % (span_idx, re.escape(span_str))
        # Only the first match can apply because of the ``^`` anchor; no
        # match leaves the text unchanged.
        return re.sub(pattern, r"\1{0}\2{0}".format(mark), text)

    text = _mark_span(x["text"], x["span2_text"], x["span2_index"], "*")
    return "wsc: " + text
lintangsutawika's avatar
lintangsutawika committed
16
17


lintangsutawika's avatar
lintangsutawika committed
18
19
def default_doc_to_text(x):
    """Build the default "Passage / Question / Answer" prompt for a WSC example.

    The pronoun (``span2_text``) is wrapped in ``*`` marks inside the passage.
    The question asks whether that pronoun refers to the noun (``span1_text``).

    Args:
        x: A WSC example providing ``"text"``, ``"span1_text"``,
            ``"span2_text"`` and ``"span2_index"``.

    Returns:
        The prompt string, ending with ``"Answer:"``.
    """
    raw_passage = x["text"]
    noun = x["span1_text"]
    pronoun = x["span2_text"]
    span_index = x["span2_index"]

    # NOTE: HuggingFace span indices are word-based not character-based.
    # Find the character offset of the ``span_index``-th whitespace-separated
    # word directly, rather than assuming exactly one space precedes it. This
    # keeps the slice after the pronoun correct when the pronoun is the first
    # word or when the passage contains runs of whitespace.
    word_starts = [m.start() for m in re.finditer(r"\S+", raw_passage)]
    if 0 <= span_index < len(word_starts):
        span_start = word_starts[span_index]
    else:
        span_start = len(raw_passage)
    pre = " ".join(raw_passage.split()[:span_index])
    post = raw_passage[span_start + len(pronoun):]

    marked = "*{}*".format(pronoun)
    # Only insert the separating space when there is text before the pronoun;
    # otherwise the passage would start with a stray space.
    passage = general_detokenize((pre + " " + marked if pre else marked) + post)

    text = (
        f"Passage: {passage}\n"
        + f'Question: In the passage above, does the pronoun "*{pronoun}*" refer to "*{noun}*"?\n'
        + "Answer:"
    )
    return text