"examples/knwl_dialo/data_processing.sh" did not exist on "f322c788208c66ef5c38a4bc8b6a909f034c0889"
utils.py 1.87 KB
Newer Older
sdtblck's avatar
sdtblck committed
1
import os
Leo Gao's avatar
Leo Gao committed
2
import re
3
import collections
sdtblck's avatar
sdtblck committed
4
5
6
7
8
9
10
11
12
13
14


class ExitCodeError(Exception):
    """Raised by ``sh`` when a shell command returns a non-zero status."""


def sh(x):
    """Run a command line through the system shell and fail loudly on error.

    Args:
        x: command string handed verbatim to ``os.system``. It is
            interpreted by the shell, so never pass untrusted input.

    Raises:
        ExitCodeError: if ``os.system`` reports a non-zero status. The
            message names the command and the raw status value.
    """
    status = os.system(x)
    if status:
        # os.system returns a platform-dependent encoded status (on POSIX a
        # wait status, not the plain exit code), so report it verbatim
        # rather than guessing at a decoding.
        raise ExitCodeError(
            "command {!r} failed with status {}".format(x, status)
        )


Jason Phang's avatar
gpt3  
Jason Phang committed
15
16
17
18
19
20
def simple_parse_args_string(args_string):
    """
    Parses something like
        arg1=val1,arg2=val2
    Into a dictionary

    Each comma-separated segment is split on its FIRST ``=`` only, so a value
    may itself contain ``=`` (e.g. ``a=x=y`` -> ``{"a": "x=y"}``). Whitespace
    around keys and values is stripped, and empty segments (from a trailing
    or doubled comma) are ignored. Values are returned as strings; no type
    conversion is done. If a key repeats, the last occurrence wins.

    Raises:
        ValueError: if a non-empty segment contains no ``=``.
    """
    args_string = args_string.strip()
    if not args_string:
        return {}
    args_dict = {}
    for arg in args_string.split(","):
        arg = arg.strip()
        if not arg:
            # Tolerate "a=1," and "a=1,,b=2" instead of crashing on unpack.
            continue
        k, sep, v = arg.partition("=")
        if not sep:
            raise ValueError(
                "Malformed argument {!r} in {!r}: expected key=value".format(
                    arg, args_string
                )
            )
        args_dict[k.strip()] = v.strip()
    return args_dict
Leo Gao's avatar
Leo Gao committed
30
31
32

def join_iters(iters):
    """Lazily concatenate an iterable of iterables into a single stream.

    Items are produced in order: everything from the first inner iterable,
    then everything from the second, and so on.
    """
    for sub_iterable in iters:
        yield from sub_iterable
Leo Gao's avatar
Leo Gao committed
34
35
36
37
38
39
40
41
42
43


def chunks(iter, n):
    """Split ``iter`` into consecutive lists of ``n`` items.

    The last list holds any remainder and may be shorter than ``n``; an empty
    list is never yielded. Every yielded list is a fresh object. When ``n`` is
    0 or negative no batch ever reaches size ``n``, so all items come out as
    one list at the end.
    """
    # NOTE: the parameter name shadows the builtin ``iter``; it is kept so
    # existing keyword callers (``chunks(iter=..., n=...)``) keep working.
    pending = []
    for element in iter:
        pending.append(element)
        if len(pending) != n:
            continue
        yield pending
        pending = []

    if pending:
        yield pending

46
47
48
49
50
51
52
53
def group(arr, fn):
    """Bucket the elements of ``arr`` by the key ``fn(element)``.

    Returns a list of lists, one per distinct key. Buckets are ordered by the
    first occurrence of their key, and elements keep their original relative
    order inside each bucket. ``fn`` is called once per element and must
    return a hashable value.
    """
    buckets = {}
    for ob in arr:
        buckets.setdefault(fn(ob), []).append(ob)
    return list(buckets.values())

Leo Gao's avatar
Leo Gao committed
54
55
56
57
58
59
def general_detokenize(string):
    """Undo common whitespace-tokenization artifacts in ``string``.

    Applies, in this exact order:
      * " n't" -> "n't"
      * drop the space before ")" and after "("
      * drop the space after and then before every double quote
      * drop the space before an apostrophe, period, or comma
    Note that the quote rules remove spaces on BOTH sides of every ``"``,
    so quoted text is glued to its neighbours.
    """
    literal_fixes = (
        (" n't", "n't"),
        (" )", ")"),
        ("( ", "("),
        ("\" ", "\""),
        (" \"", "\""),
    )
    for before, after in literal_fixes:
        string = string.replace(before, after)
    return re.sub(r" (['.,])", r"\1", string)


class Reorderer:
    """Deduplicate items by key, sort them by that key, and map results back.

    Items of ``arr`` whose ``fn`` values are equal are collapsed into one
    representative (the first such item). Representatives are sorted by
    ``fn``. ``fn`` values must therefore be hashable and mutually orderable.
    After the caller processes ``get_reordered()``, ``get_original`` copies
    each result back to every original position sharing that key.
    """

    def __init__(self, arr, fn):
        self.size = len(arr)
        indexed = list(enumerate(arr))
        buckets = group(indexed, lambda pair: fn(pair[1]))

        entries = []
        for bucket in buckets:
            # All items in a bucket share one key, so the first stands in
            # for the rest; remember every original index it represents.
            positions = [pos for pos, _ in bucket]
            representative = bucket[0][1]
            entries.append((positions, representative))
        entries.sort(key=lambda entry: fn(entry[1]))

        # List of (original_indices, representative_item), sorted by key.
        self.arr = entries

    def get_reordered(self):
        """Return the deduplicated representatives in sorted order."""
        return [representative for _, representative in self.arr]

    def get_original(self, newarr):
        """Expand results for ``get_reordered()`` back to the original layout.

        ``newarr[i]`` must be the result for the i-th reordered item; it is
        written to every original index that item represents. Surplus
        entries in ``newarr`` are ignored. Missing entries leave positions
        uncovered, which trips the assertion.
        """
        restored = [None] * self.size
        filled = [False] * self.size

        for (positions, _), value in zip(self.arr, newarr):
            for pos in positions:
                restored[pos] = value
                filled[pos] = True

        # NOTE(review): stripped under ``python -O``; kept as-is so the
        # exception type seen by callers does not change.
        assert all(filled)

        return restored