"src/treelearner/vscode:/vscode.git/clone" did not exist on "da3b4c19dc76ed76a191d7d70098f7080495000c"
util.py 13 KB
Newer Older
Rayyyyy's avatar
Rayyyyy committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
import os
import sys
import json
import time
import shutil
import pickle
import string
import logging
import pathlib
from contextlib import contextmanager
from dataclasses import dataclass
from datetime import datetime
from collections import defaultdict, OrderedDict
from typing import Optional, Tuple, Union, List, Callable, Dict, Any, Mapping

import pytz
import torch
import numpy as np
import pandas as pd
from transformers.tokenization_utils import PreTrainedTokenizer

logger = logging.getLogger(__name__)


@contextmanager
def do_nothing():
    yield

def makedirs(path):
    """Create the parent directory of `path` (treats `path` as a file path) and return `path`."""
    p = pathlib.Path(path)
    p.parent.mkdir(parents=True, exist_ok=True)
    return path

def clear_dir(directory):
    """Delete everything inside `directory`, creating it first if it does not exist."""
    if not os.path.exists(directory):
        os.makedirs(directory, exist_ok=True)
    for filename in os.listdir(directory):
        file_path = os.path.join(directory, filename)
        try:
            if os.path.isfile(file_path) or os.path.islink(file_path):
                os.unlink(file_path)
            elif os.path.isdir(file_path):
                shutil.rmtree(file_path)
        except Exception as e:
            print(f"Failed to delete {file_path}. Reason: {e}")

def split_file_dir_name_ext(path):
    """Return the directory, name, and extension of a given file."""
    p = pathlib.Path(path)
    assert p.is_file()
    return p.parent, p.stem, p.suffix

def save_pickle(obj, path: str):
    """Pickle `obj` to `path`, creating parent directories as needed."""
    makedirs(path)
    with open(path, "wb") as f:
        return pickle.dump(obj, f)

def load_pickle(path):
    with open(path, "rb") as f:
        return pickle.load(f)
    
def save_json(obj, path: str):
    """Serialize `obj` as JSON to `path`, creating parent directories as needed."""
    makedirs(path)
    with open(path, "w") as f:
        return json.dump(obj, f, ensure_ascii=False)

def load_json(path, lines=False):
    if lines:
        output = []
        with open(path, "r") as f:
            for line in f:
                output.append(json.loads(line))
        return output
    else:
        with open(path, "r") as f:
            return json.load(f)

@contextmanager
def filelock(path, process_index=0):
    # wait until the lock file disappears; only log the waiting message once
    i = 0
    while os.path.exists(path):
        if i == 0 and process_index == 0:
            logger.info("found lock, waiting for other programs...")
        time.sleep(3)
        i = 1
    if process_index == 0:
        save_json("this is a lock", path)
    yield
    if process_index == 0:
        os.remove(path)
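
# Usage sketch (hypothetical lock path and `local_rank` variable): only the rank-0
# process creates/removes the lock file; other processes spin until it disappears.
#
#   with filelock("preprocess.lock", process_index=local_rank):
#       build_cache()  # critical section guarded by the lock file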

def normalize_text(text, ignore_case=True, ignore_punctuation=True, ignore_space=True, ignore_number=False):
    """Normalize a string (or list of strings) by optionally lowercasing it and removing punctuation, digits, and redundant whitespace."""
    if isinstance(text, str):
        text = [text]
        unpack = True
    else:
        unpack = False
    if ignore_case:
        text = np.char.lower(text)
    if ignore_punctuation:
        repl_table = str.maketrans("", "", string.punctuation)
        text = np.char.translate(text, table=repl_table)
    if ignore_number:
        repl_table = str.maketrans("", "", string.digits)
        text = np.char.translate(text, table=repl_table)
    if ignore_space:
        for i, words in enumerate(np.char.split(text)):
            text[i] = " ".join(words)
    if isinstance(text, np.ndarray):
        text = text.tolist()
    if unpack:
        text = text[0]
    return text
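
# Examples of the expected behavior (default flags lowercase, strip punctuation,
# and collapse whitespace):
#   normalize_text("Hello,  World! 42")                     -> "hello world 42"
#   normalize_text("Hello, World! 42", ignore_number=True)  -> "hello world"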

def min_max_normalize(array):
    """Min-max normalize `array` to [0, 1] along its last axis (assumes a 2D numpy array)."""
    return (array - array.min(-1)[:, None]) / (array.max(-1) - array.min(-1))[:, None]
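
# Example (row-wise over the last axis):
#   min_max_normalize(np.array([[1.0, 2.0, 3.0]]))  -> array([[0. , 0.5, 1. ]])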

def get_max_length_in_nested_lists(lst):
    """Return the maximum length of the innermost lists in an arbitrarily nested list."""
    if len(lst) and isinstance(lst[0], list):
        lengths = []
        for elem in lst:
            length = get_max_length_in_nested_lists(elem)
            lengths.append(length)
        max_length = max(lengths)
        return max_length
    else:
        return len(lst)

def pad_nested_lists(lst, max_length, padding_value, padding_side="right"):
    """Pad all innermost lists of `lst` to `max_length` and return (padded_lst, mask)."""
    if isinstance(lst, list) and len(lst) and isinstance(lst[0], list):
        masks = []
        for i, elem in enumerate(lst):
            lst[i], mask = pad_nested_lists(elem, max_length, padding_value, padding_side)
            masks.append(mask)
        return lst, masks
    elif isinstance(lst, list):
        if padding_side == "right":
            mask = [1] * len(lst) + [0] * (max_length - len(lst))
            lst = lst + [padding_value for _ in range(max_length - len(lst))]
            return lst, mask
        else:
            mask = [0] * (max_length - len(lst)) + [1] * len(lst)
            lst = [padding_value for _ in range(max_length - len(lst))] + lst
            return lst, mask
    else:
        raise NotImplementedError(f"Unrecognized type {type(lst)}: {lst}")
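
# Example (right-padding a batch of two sequences to length 3 with 0):
#   pad_nested_lists([[1, 2], [3, 4, 5]], max_length=3, padding_value=0)
#   -> ([[1, 2, 0], [3, 4, 5]], [[1, 1, 0], [1, 1, 1]])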

def mask_nested_lists(lst, mask_target, mask_value=0):
    if isinstance(lst[0], list):
        for i, elem in enumerate(lst):
            lst[i] = mask_nested_lists(elem, mask_target, mask_value)
        return lst
    else:
        return [x if x != mask_target else mask_value for x in lst]
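
# Example (replace every -100 with 0, preserving the nesting):
#   mask_nested_lists([[1, -100], [2, 3]], mask_target=-100)  -> [[1, 0], [2, 3]]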

def are_elements_of_same_length(lst: List):
    """Return True only if `lst` is a list of lists that all share the same length."""
    if not isinstance(lst[0], list):
        return False

    length = len(lst[0])
    return all(len(x) == length if isinstance(x, list) else False for x in lst)

def add_eos(inputs: Mapping, eos_token_id: int):
    """Append an EOS token to tokenized `inputs`, updating all aligned keys in place."""
    for k, v in inputs.items():
        assert isinstance(v, list), "Make sure the tokenizer returns lists (do not set return_tensors)!"
        if k == "input_ids":
            v = v + [eos_token_id]
        elif k == "position_ids":
            v = v + [v[-1] + 1]
        elif k in ["attention_mask", "token_type_ids"]:
            v = v + v[-1:]
        else:
            raise NotImplementedError(f"Inputs key {k} not implemented!")
        inputs[k] = v
    return inputs
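
# Example (appending eos_token_id=2 to aligned tokenizer outputs):
#   add_eos({"input_ids": [5, 6], "attention_mask": [1, 1]}, eos_token_id=2)
#   -> {"input_ids": [5, 6, 2], "attention_mask": [1, 1, 1]}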

def remove_eos(inputs: Mapping, eos_token_id: int):
    """Remove the first EOS token from tokenized `inputs`, updating all aligned keys in place."""
    eos_idx = inputs["input_ids"].index(eos_token_id)
    for k in inputs:
        inputs[k].pop(eos_idx)
    return inputs

def mix_parameters(models: List[torch.nn.Module], weights: Optional[List[float]]=None):
    """Mix parameters of different models according to given weights.
    
    Returns:
        the model with mixed parameters.
    """
    new_state_dict = OrderedDict()
    if weights is None:
        weights = [1 / len(models) for _ in range(len(models))]
    else:
        assert len(weights) == len(models), "Make sure the number of mix weights equals the number of models!"

    for name_param_pairs in zip(*[model.state_dict().items() for model in models]):
        names = [name_param_pair[0] for name_param_pair in name_param_pairs]
        params = [name_param_pair[1] for name_param_pair in name_param_pairs]

        assert all(name == names[0] for name in names), f"Found incompatible key in {names}!"
        name = names[0]
        mixed_param = None

        # there may be non-float parameters stored, which should not be mixed
        if params[0].dtype not in [torch.float16, torch.bfloat16, torch.float32]:
            assert all((param == params[0]).all() for param in params), f"Found incompatible value in non-float tensor {params}!"
            new_state_dict[name] = params[0]
            continue

        for weight, param in zip(weights, params):
            if mixed_param is None:
                mixed_param = weight * param
            else:
                mixed_param += weight * param
        new_state_dict[name] = mixed_param
            
    model = models[0]
    info = model.load_state_dict(new_state_dict)
    print(info)
    return model
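
# Usage sketch (hypothetical checkpoints of identical architecture): average the
# weights of two models; note that `model_a` itself is updated and returned.
#
#   merged = mix_parameters([model_a, model_b], weights=[0.5, 0.5])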


class FileLogger:
    def __init__(self, log_file) -> None:
        self.log_file = log_file
    
    def log(self, metrics, **kwargs):
        with open(self.log_file, "a+") as f:
            # get current time (named time_line to avoid shadowing the `time` module)
            tz = pytz.timezone('Asia/Shanghai')
            time_line = f"{'Time': <10}: {json.dumps(datetime.now(tz).strftime('%Y-%m-%d, %H:%M:%S'), ensure_ascii=False)}\n"
            command = f"{'Command': <10}: {json.dumps(' '.join(sys.argv), ensure_ascii=False)}\n"
            metrics = f"{'Metrics': <10}: {json.dumps(metrics, ensure_ascii=False)}\n"
            msg = time_line + command
            print(msg + metrics)

            for key, value in kwargs.items():
                try:
                    msg += f"{key: <10}: {json.dumps(value, ensure_ascii=False)}\n"
                except Exception:
                    print(key)
                    print(value)
                    raise
            msg += metrics
            f.write(str(msg) + "\n")
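
# Usage sketch (hypothetical log path and metrics): each call appends a timestamp,
# the launching command, any extra keyword fields, and the metrics to the file.
#
#   file_logger = FileLogger(makedirs("logs/metrics.log"))
#   file_logger.log({"rouge-l": 0.31}, Model="llama-7b")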


class Sequential_Sampler:
    """
    The sampler used in creating sequential dataloader.
    """
    def __init__(self, dataset_length:int, num_replicas:int, rank:int) -> None:
        """
        Args:
            dataset_length: length of the dataset
            num_replicas: number of splits
            rank: the current process id

        Attributes:
            start: the starting index
            end: the ending index
        """
        super().__init__()
        len_per_worker = dataset_length / num_replicas
        # force to set rank==0 because when world_size==1 the local_rank is -1 by default
        if num_replicas == 1:
            rank = 0
        self.start = round(len_per_worker * rank)
        self.end = round(len_per_worker * (rank + 1))
        self.rank = rank

    def __iter__(self):
        return iter(range(self.start, self.end))

    def __len__(self):
        return self.end - self.start
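
# Example (10 samples split across 3 ranks; rounding keeps the shards near-equal):
#   list(Sequential_Sampler(10, num_replicas=3, rank=1))  -> [3, 4, 5, 6]
#   (rank 0 covers [0, 3), rank 1 covers [3, 7), rank 2 covers [7, 10))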


class DatasetProcessFn:
    """Wrapper for any user-defined process function for huggingface datasets.

    1. Process batched examples by looping the process function over them;
    2. Gather returned examples when data augmentation happens (set augment=True);
    3. Expose each example's index to the process function via the `_index` keyword (requires calling `datasets.map` with `with_indices=True`).

    The wrapped function should take in any needed columns and return a dict with 1 or more samples.
    """
    def __init__(self, augment=False):
        self.augment = augment

    def __call__(self, _process_fn):
        def process(*args):
            sample_or_batch_sample = args[0]
            if len(args) == 1:
                pass
            elif len(args) == 2:
                indices = args[1]
                # detach the slice so that _index will not be set in the original data
                sample_or_batch_sample = sample_or_batch_sample.copy()
                sample_or_batch_sample["_index"] = indices
            else:
                raise NotImplementedError(f"Found more than 2 arguments {args}!")

            keys = list(sample_or_batch_sample.keys())
            func_args = [sample_or_batch_sample[k] for k in keys]
            
            # FIXME: a single (non-batched) sample whose values all happen to share the same length is mis-detected as a batch and handled incorrectly here
            if are_elements_of_same_length(func_args):
                outputs = defaultdict(list)
                for arg in zip(*func_args):
                    # get each element in a batch
                    kwargs = {keys[j]: arg[j] for j in range(len(arg))}
                    output = _process_fn(**kwargs)
                    if output is not None:
                        for k, v in output.items():
                            if self.augment:
                                outputs[k].extend(v)
                            else:
                                outputs[k].append(v)
            else:
                outputs = _process_fn(**sample_or_batch_sample)
                if outputs is None:
                    raise ValueError(f"Found None returned from process_fn. Make sure you set 'batched=True' when trying to augment/distract samples in the datasets!")
            return dict(outputs)
        return process
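
# Usage sketch (hypothetical column names and `tokenizer`; assumes a HuggingFace
# `datasets.Dataset`): the wrapped function receives one example's columns as
# keyword arguments, plus `_index` when `map` is called with `with_indices=True`.
#
#   @DatasetProcessFn()
#   def tokenize_fn(text, _index=None, **kwds):
#       return {"input_ids": tokenizer.encode(text), "index": _index}
#
#   dataset = dataset.map(tokenize_fn, batched=True, with_indices=True,
#                         remove_columns=dataset.column_names)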


@dataclass
class DefaultDataCollator:
    """
    Data collator that can:
    1. Dynamically pad all inputs received. The inputs must be dict of lists.
    2. Add position_ids based on attention_mask if required.
    """
    tokenizer: PreTrainedTokenizer
    attention_padding_value: int = 0
    label_padding_value: int = -100
    add_position_ids: bool = False

    def __call__(self, batch_elem: List) -> Dict[str, Any]:
        first_elem = batch_elem[0]
        return_batch = {}
        
        for key, value in first_elem.items():
            # HACK: any key containing attention_mask must be attention_mask
            # important to assign different pad token for different types of inputs
            if "attention_mask" in key:
                pad_token_id = self.attention_padding_value
            elif "label" in key:
                pad_token_id = self.label_padding_value
            else:
                pad_token_id = self.tokenizer.pad_token_id

            batch_value = [elem[key] for elem in batch_elem]
            # pad all lists and nested lists
            if isinstance(value, list):
                max_length = get_max_length_in_nested_lists(batch_value)
                batch_value, _ = pad_nested_lists(batch_value, max_length, pad_token_id, self.tokenizer.padding_side)

            return_batch[key] = torch.tensor(batch_value)

            if "attention_mask" in key and self.add_position_ids:
                value = return_batch[key]
                position_ids = value.cumsum(-1) - 1
                position_ids = position_ids.masked_fill(value == 0, 0)
                return_batch[key.replace("attention_mask", "position_ids")] = position_ids
        return return_batch
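
# Usage sketch (hypothetical `tokenizer` and `dataset`): pad variable-length
# examples into rectangular tensors and derive position_ids from attention_mask.
#
#   collator = DefaultDataCollator(tokenizer=tokenizer, add_position_ids=True)
#   loader = torch.utils.data.DataLoader(dataset, batch_size=8, collate_fn=collator)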