realm_dataset_utils.py 13.1 KB
Newer Older
1
2
3
4
5
6
7
8
import itertools
import os
import random
import time

import numpy as np
import spacy
import torch
9
10
11
12
13
14
# Optional stanza NER pipeline (CoNLL-03 NER model), used by get_stanza_ner_mask.
# Best effort: if stanza is not installed or the pipeline cannot be built,
# stanza_pipeline is left as None. Calling get_stanza_ner_mask then raises, and
# build_realm_training_sample falls back to random masking.
try:
    import stanza
    processors_dict = {'tokenize': 'default', 'mwt': 'default', 'ner': 'conll03'}
    stanza_pipeline = stanza.Pipeline('en', processors=processors_dict, use_gpu=True)
except Exception:
    # Narrower than a bare `except:` so KeyboardInterrupt/SystemExit still propagate.
    stanza_pipeline = None
15
16

from megatron.data.dataset_utils import create_masked_lm_predictions, pad_and_convert_to_numpy
17
from megatron import get_args, get_tokenizer, print_rank_0, mpu
18
19
20
21
22
23
24

# spaCy large English model, loaded once at import time. Used for entity
# detection in salient_span_mask and for DATE entities in get_stanza_ner_mask.
SPACY_NER = spacy.load('en_core_web_lg')


def build_realm_training_sample(sample, max_seq_length,
                                vocab_id_list, vocab_id_to_token_dict,
                                cls_id, sep_id, mask_id, pad_id,
                                masked_lm_prob, block_ner_mask, cased_tokens,
                                cased_tokenizer, np_rng):
    """Build one masked-LM training sample for REALM.

    The sentences in ``sample`` are flattened, truncated to
    ``max_seq_length - 2`` ids and wrapped as ``[CLS] ... [SEP]``. The masking
    strategy is chosen in this order:

    1. ``args.use_regular_masking``: random masking via
       ``create_masked_lm_predictions``.
    2. ``block_ner_mask`` is not None: the precomputed per-sentence 0/1 mask is
       flattened and applied. With ``args.use_random_spans`` it is first rotated
       by a random offset.
    3. ``args.cased_data_path`` is set: NER masking on the cased text
       (``get_stanza_ner_mask``).
    4. Otherwise: spacy NER masking on the uncased text (``salient_span_mask``).

    If strategy 3 or 4 raises, or finds no entity, random masking (1) is used.

    Args:
        sample: list of sentences, each a sequence of uncased token ids.
            Not modified.
        max_seq_length: final padded length of the sample.
        vocab_id_list, vocab_id_to_token_dict: forwarded to
            ``create_masked_lm_predictions``.
        cls_id, sep_id, mask_id, pad_id: special token ids.
        masked_lm_prob: masking probability for random masking.
        block_ner_mask: optional per-sentence 0/1 masks aligned with ``sample``.
        cased_tokens, cased_tokenizer: cased view of ``sample``, used by
            strategy 3.
        np_rng: random generator forwarded to ``create_masked_lm_predictions``.

    Returns:
        dict with the 'tokens', 'labels', 'loss_mask' and 'pad_mask' arrays
        produced by ``pad_and_convert_to_numpy``.
    """
    tokens = list(itertools.chain(*sample))[:max_seq_length - 2]
    tokens, tokentypes = create_single_tokens_and_tokentypes(tokens, cls_id, sep_id)

    def random_masking():
        # Random masking; also the fallback when salient-span masking fails.
        max_predictions_per_seq = masked_lm_prob * max_seq_length
        rand_tokens, rand_positions, rand_labels, _ = create_masked_lm_predictions(
            tokens, vocab_id_list, vocab_id_to_token_dict, masked_lm_prob,
            cls_id, sep_id, mask_id, max_predictions_per_seq, np_rng)
        return rand_tokens, rand_positions, rand_labels

    args = get_args()
    if args.use_regular_masking:
        masked_tokens, masked_positions, masked_labels = random_masking()

    elif block_ner_mask is not None:
        block_ner_mask = list(itertools.chain(*block_ner_mask))[:max_seq_length - 2]
        # The emptiness guard avoids np.random.randint(0), which raises ValueError.
        if args.use_random_spans and block_ner_mask:
            # NOTE(review): this uses the global np.random rather than np_rng,
            # so the rotation is not controlled by the per-sample seed.
            rand_idx = np.random.randint(len(block_ner_mask))
            block_ner_mask = block_ner_mask[rand_idx:] + block_ner_mask[:rand_idx]
        # Never mask the [CLS]/[SEP] positions added above.
        block_ner_mask = [0] + block_ner_mask + [0]
        masked_tokens, masked_positions, masked_labels = get_arrays_using_ner_mask(
            tokens, block_ner_mask, mask_id)

    else:
        salient = None
        try:
            if args.cased_data_path is not None:
                total_len = sum(len(sent) for sent in sample)
                if total_len > max_seq_length - 2:
                    # Truncate the last sentence so the total is max_seq_length - 2,
                    # matching `tokens`. Work on a shallow copy so the caller's
                    # sample list is not mutated.
                    offset = -(total_len - (max_seq_length - 2))
                    sample = list(sample)
                    sample[-1] = sample[-1][:offset]
                salient = get_stanza_ner_mask(sample, cased_tokens, cased_tokenizer,
                                              cls_id, sep_id, mask_id)
            else:
                # Returns None when spacy finds no usable entity.
                salient = salient_span_mask(tokens, mask_id)
        except Exception:
            # Best effort: any failure in NER masking falls back to random masking.
            salient = None
        if salient is None:
            salient = random_masking()
        masked_tokens, masked_positions, masked_labels = salient

    tokens_np, _tokentypes_np, labels_np, padding_mask_np, loss_mask_np \
        = pad_and_convert_to_numpy(masked_tokens, tokentypes, masked_positions,
                                   masked_labels, pad_id, max_seq_length)

    train_sample = {
        'tokens': tokens_np,
        'labels': labels_np,
        'loss_mask': loss_mask_np,
        'pad_mask': padding_mask_np
    }
    return train_sample


79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
def get_stanza_ner_mask(tokens, cased_tokens, cased_tokenizer, cls_id, sep_id, mask_id):
    """Use stanza (plus spacy for dates) to build an NER salient-span mask.

    For each sentence, the cased text is run through ``stanza_pipeline`` and
    ``SPACY_NER``. Only spacy's DATE entities are kept, because the CoNLL-03
    stanza model does not tag dates. Up to three randomly chosen entities are
    mapped onto the uncased tokens and masked. The loop stops early once more
    than ``int(0.12 * len(cased sentence))`` distinct positions are masked.

    Args:
        tokens: list of sentences, each a sequence of uncased token ids. They
            are decoded with ``get_tokenizer()``, which is assumed to be uncased.
        cased_tokens: the same sentences as cased token ids, zipped one-to-one
            with ``tokens``.
        cased_tokenizer: tokenizer used to decode ``cased_tokens``.
        cls_id, sep_id: ids placed around the flattened sentences.
        mask_id: id written at masked positions.

    Returns:
        (masked_tokens, masked_positions, masked_labels) from
        ``get_arrays_using_ner_mask`` for ``[CLS] + sentences + [SEP]``.
    """
    uncased_tokenizer = get_tokenizer()
    block_ner_mask = []

    for cased_sent_ids, uncased_sent_ids in zip(cased_tokens, tokens):
        token_pos_map = id_to_str_pos_map(uncased_sent_ids, uncased_tokenizer)

        # Get the cased string and run NER with both toolkits.
        cased_sent_str = join_str_list(cased_tokenizer.tokenizer.convert_ids_to_tokens(cased_sent_ids))
        entities = stanza_pipeline(cased_sent_str).ents
        spacy_entities = SPACY_NER(cased_sent_str).ents

        # CoNLL doesn't do dates, so take the DATE entities from spacy.
        entities = [e for e in entities if e.text != 'CLS']
        entities.extend([e for e in spacy_entities if (e.text != 'CLS' and e.label_ == 'DATE')])

        # Visit entities in random order (at most 3), targeting ~12% of tokens masked.
        entity_indices = np.arange(len(entities))
        np.random.shuffle(entity_indices)
        target_num_masks = int(len(cased_sent_ids) * 0.12)

        # A set, so positions covered by overlapping entity spans are counted
        # once toward target_num_masks.
        masked_positions = set()
        for entity_idx in entity_indices[:3]:
            if len(masked_positions) > target_num_masks:
                break

            selected_entity = entities[entity_idx]

            # token_pos_map offsets are measured without the leading space that
            # join_str_list emits. mask_end stops at the first token starting at
            # or after end_char. mask_start stops at the first token (before
            # mask_end) starting after start_char. The span then begins one
            # token earlier.
            mask_start = mask_end = 0
            set_mask_start = False
            while mask_end < len(token_pos_map) and token_pos_map[mask_end] < selected_entity.end_char:
                if token_pos_map[mask_start] > selected_entity.start_char:
                    set_mask_start = True
                if not set_mask_start:
                    mask_start += 1
                mask_end += 1

            masked_positions.update(range(mask_start - 1, mask_end))

        ner_mask = [0] * len(uncased_sent_ids)
        for pos in masked_positions:
            ner_mask[pos] = 1
        block_ner_mask.extend(ner_mask)

    tokens = list(itertools.chain(*tokens))
    tokens = [cls_id] + tokens + [sep_id]
    # Never mask the [CLS]/[SEP] positions.
    block_ner_mask = [0] + block_ner_mask + [0]
    return get_arrays_using_ner_mask(tokens, block_ner_mask, mask_id)


140
141
142
143
144
145
146
147
148
149
150
151
152
153
def get_arrays_using_ner_mask(tokens, block_ner_mask, mask_id):
    """Apply a 0/1 position mask to a token-id sequence.

    Args:
        tokens: token ids; must support ``.copy()`` and indexing (e.g. a list).
            Not modified.
        block_ner_mask: per-position flags. Position i is masked when
            ``block_ner_mask[i] == 1``. It must be at least as long as
            ``tokens``; extra entries are ignored.
        mask_id: id written at masked positions.

    Returns:
        (masked_tokens, masked_positions, masked_labels): a copy of ``tokens``
        with masked positions replaced by ``mask_id``, the masked indices in
        increasing order, and the original ids at those indices.
    """
    masked_tokens = tokens.copy()
    masked_positions = []
    masked_labels = []

    for pos, token in enumerate(tokens):
        if block_ner_mask[pos] == 1:
            masked_positions.append(pos)
            masked_labels.append(token)
            masked_tokens[pos] = mask_id

    return masked_tokens, masked_positions, masked_labels


163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
def create_single_tokens_and_tokentypes(_tokens, cls_id, sep_id):
    """Wrap a token-id sequence as ``[CLS] tokens [SEP]``.

    Returns:
        (wrapped, segment_ids): the new token list and an all-zero token-type
        list of the same length (a single segment).
    """
    wrapped = [cls_id, *_tokens, sep_id]
    segment_ids = [0 for _ in wrapped]
    return wrapped, segment_ids


def join_str_list(str_list):
    """Join word-piece strings back into text.

    Each piece is preceded by a single space, except ``##`` continuation
    pieces, which are appended to the previous piece with the ``##`` removed.
    The result therefore begins with a space unless the first piece is a
    continuation.

    >>> join_str_list(["the", "cap", "##ital"])
    ' the capital'
    """
    # Collect the pieces and join once; repeated `str +=` in a loop is quadratic.
    pieces = [s[2:] if s.startswith("##") else " " + s for s in str_list]
    return "".join(pieces)


def id_to_str_pos_map(token_ids, tokenizer):
    """Map each token to the start offset of its text in the joined string.

    The reference string is ``join_str_list`` of the decoded tokens without its
    leading space. Tokens are separated by one space, except ``##`` continuation
    pieces, which are glued to the previous token with the ``##`` stripped.

    Args:
        token_ids: sequence of token ids.
        tokenizer: object exposing ``tokenizer.convert_ids_to_tokens``.

    Returns:
        list of int start offsets, one per token; ``[]`` for empty input.
    """
    token_strs = tokenizer.tokenizer.convert_ids_to_tokens(token_ids)
    if not token_strs:
        # Previously this indexed token_strs[-1] and raised IndexError.
        return []

    pos_map = [0]
    for cur, nxt in zip(token_strs, token_strs[1:]):
        # Do not count the "##" marker of the current piece.
        cur_len = len(cur) - 2 if cur.startswith("##") else len(cur)
        # A following "##" piece is glued on; anything else is preceded by a space.
        gap = 0 if nxt.startswith("##") else 1
        pos_map.append(pos_map[-1] + cur_len + gap)

    # Sanity check: the end of the last token must match the joined length
    # (minus the leading space that join_str_list adds).
    last = token_strs[-1]
    total_len = pos_map[-1] + len(last) - (2 if last.startswith("##") else 0)
    joined_len = len(join_str_list(token_strs))
    assert total_len == joined_len - 1, (total_len, joined_len)

    return pos_map


def salient_span_mask(tokens, mask_id):
    """Mask one randomly chosen spacy named-entity span in ``tokens``.

    ``tokens`` are vocab ids, not text tokens. They are decoded with the default
    tokenizer, joined with ``join_str_list`` and run through ``SPACY_NER``.
    Entities whose text is "CLS", or whose label is CARDINAL, TIME, PERCENT,
    MONEY, QUANTITY or ORDINAL, are ignored.

    Returns:
        None if no usable entity is found. Otherwise
        (output_tokens, masked_positions, labels): a copy of ``tokens`` with
        the entity's tokens replaced by ``mask_id``, the masked indices, and
        the original ids at those indices.
    """
    tokenizer = get_tokenizer()
    text = join_str_list(tokenizer.tokenizer.convert_ids_to_tokens(tokens))

    skipped_labels = ['CARDINAL', 'TIME', 'PERCENT', 'MONEY', 'QUANTITY', 'ORDINAL']
    candidates = [ent for ent in SPACY_NER(text).ents
                  if ent.text != "CLS" and ent.label_ not in skipped_labels]
    if not candidates:
        return None

    entity = candidates[np.random.randint(0, len(candidates))]
    starts = id_to_str_pos_map(tokens, tokenizer)

    # Span end: the first token (scanning from 0) starting at or after end_char.
    span_end = 0
    while span_end < len(starts) and starts[span_end] < entity.end_char:
        span_end += 1
    # Span start: one before the first token in [0, span_end) that starts after
    # start_char (or one before span_end if there is none). Offsets in `starts`
    # omit the leading space that join_str_list emits.
    first_after = next((idx for idx in range(span_end) if starts[idx] > entity.start_char),
                       span_end)
    masked_positions = list(range(first_after - 1, span_end))

    labels = [tokens[idx] for idx in masked_positions]
    output_tokens = tokens.copy()
    for idx in masked_positions:
        output_tokens[idx] = mask_id

    return output_tokens, masked_positions, labels
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268


def get_block_samples_mapping(block_dataset, title_dataset, data_prefix, num_epochs,
                              max_num_samples, max_seq_length, seed, name):
    """Load the block samples index mapping, building and caching it first if needed.

    At least one of ``num_epochs`` / ``max_num_samples`` must be truthy. The
    other is then replaced by a sentinel meaning "unbounded". On data-parallel
    rank 0 the mapping is built with ``helpers.build_blocks_mapping`` and saved
    to a ``.npy`` file whose name encodes the parameters, unless that file
    already exists. All data-parallel ranks then synchronize and load the file.

    Raises:
        ValueError: if neither ``num_epochs`` nor ``max_num_samples`` is given.
    """
    epochs_sentinel = np.iinfo(np.int32).max - 1
    samples_sentinel = np.iinfo(np.int64).max - 1

    if not num_epochs:
        if not max_num_samples:
            raise ValueError("Need to specify either max_num_samples "
                             "or num_epochs")
        num_epochs = epochs_sentinel
    if not max_num_samples:
        max_num_samples = samples_sentinel

    # Cache file name: only non-sentinel limits are included.
    name_parts = [data_prefix, '_{}_indexmap'.format(name)]
    if num_epochs != epochs_sentinel:
        name_parts.append('_{}ep'.format(num_epochs))
    if max_num_samples != samples_sentinel:
        name_parts.append('_{}mns'.format(max_num_samples))
    name_parts.append('_{}msl'.format(max_seq_length))
    name_parts.append('_{}s'.format(seed))
    name_parts.append('.npy')
    indexmap_filename = ''.join(name_parts)

    needs_build = (mpu.get_data_parallel_rank() == 0
                   and not os.path.isfile(indexmap_filename))
    if needs_build:
        print(' > WARNING: could not find index map file {}, building '
              'the indices on rank 0 ...'.format(indexmap_filename))

        # The compiled helper expects these exact dtypes.
        assert block_dataset.doc_idx.dtype == np.int64
        assert block_dataset.sizes.dtype == np.int32

        verbose = torch.distributed.get_rank() == 0
        build_start = time.time()
        print_rank_0(' > building samples index mapping for {} ...'.format(name))
        from megatron.data.dataset_utils import compile_helper
        compile_helper()
        from megatron.data import helpers
        mapping = helpers.build_blocks_mapping(
            block_dataset.doc_idx,
            block_dataset.sizes,
            title_dataset.sizes,
            num_epochs,
            max_num_samples,
            max_seq_length - 3,  # account for added tokens
            seed,
            verbose)
        print_rank_0(' > done building samples index mapping')
        np.save(indexmap_filename, mapping, allow_pickle=True)
        print_rank_0(' > saved the index mapping in {}'.format(indexmap_filename))
        print_rank_0(' > elapsed time to build and save samples mapping '
                     '(seconds): {:4f}'.format(time.time() - build_start))

    # Stand-in for a barrier: the nccl barrier assumes device_index == rank,
    # which does not hold with model parallelism. Every data-parallel rank
    # contributes 1, so the sum must equal the group size.
    counts = torch.cuda.LongTensor([1])
    torch.distributed.all_reduce(counts, group=mpu.get_data_parallel_group())
    assert counts[0].item() == torch.distributed.get_world_size(
        group=mpu.get_data_parallel_group())

    print_rank_0(' > loading indexed mapping from {}'.format(indexmap_filename))
    load_start = time.time()
    mapping = np.load(indexmap_filename, allow_pickle=True)
    print_rank_0('    loaded indexed file in {:3.3f} seconds'.format(
        time.time() - load_start))
    print_rank_0('    total number of samples: {}'.format(mapping.shape[0]))

    return mapping