realm_dataset_utils.py
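"""Utilities for building REALM-style training samples: masked-LM sample
construction (regular, NER-block-mask, and salient span masking) and the
block samples index mapping used by the data pipeline."""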
import itertools
import os
import random
import time

import numpy as np
import spacy
import torch

from megatron.data.dataset_utils import create_masked_lm_predictions, pad_and_convert_to_numpy
from megatron import get_args, get_tokenizer, print_rank_0, mpu

# spaCy NER model used by salient span masking to locate named entities.
SPACY_NER = spacy.load('en_core_web_lg')


def build_realm_training_sample(sample, max_seq_length,
                                vocab_id_list, vocab_id_to_token_dict,
                                cls_id, sep_id, mask_id, pad_id,
                                masked_lm_prob, block_ner_mask, np_rng):
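    """Build a single REALM training sample from a block of sentences.

    Tokens are wrapped with [CLS]/[SEP] and masked with one of three schemes:
    regular BERT-style masking, a precomputed block NER mask, or salient span
    masking (falling back to regular masking when no usable entity is found).
    Returns a dict of padded numpy arrays: tokens, labels, loss_mask, pad_mask.
    """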
    tokens = list(itertools.chain(*sample))[:max_seq_length - 2]
    tokens, tokentypes = create_single_tokens_and_tokentypes(tokens, cls_id, sep_id)

    args = get_args()
    if args.use_regular_masking:
        max_predictions_per_seq = masked_lm_prob * max_seq_length
        masked_tokens, masked_positions, masked_labels, _ = create_masked_lm_predictions(
            tokens, vocab_id_list, vocab_id_to_token_dict, masked_lm_prob,
            cls_id, sep_id, mask_id, max_predictions_per_seq, np_rng)
    elif block_ner_mask is not None:
        block_ner_mask = list(itertools.chain(*block_ner_mask))[:max_seq_length - 2]
        block_ner_mask = [0] + block_ner_mask + [0]
        masked_tokens, masked_positions, masked_labels = get_arrays_using_ner_mask(tokens, block_ner_mask, mask_id)
    else:
        masked_data = salient_span_mask(tokens, mask_id)
        if masked_data is not None:
            masked_tokens, masked_positions, masked_labels = masked_data
        else:
            # salient_span_mask found no usable named entity; fall back to
            # regular masked-LM masking.
            max_predictions_per_seq = masked_lm_prob * max_seq_length
            masked_tokens, masked_positions, masked_labels, _ = create_masked_lm_predictions(
                tokens, vocab_id_list, vocab_id_to_token_dict, masked_lm_prob,
                cls_id, sep_id, mask_id, max_predictions_per_seq, np_rng)

    tokens_np, tokentypes_np, labels_np, padding_mask_np, loss_mask_np \
        = pad_and_convert_to_numpy(masked_tokens, tokentypes, masked_positions,
                                   masked_labels, pad_id, max_seq_length)

    train_sample = {
        'tokens': tokens_np,
        'labels': labels_np,
        'loss_mask': loss_mask_np,
        'pad_mask': padding_mask_np
    }
    return train_sample


def get_arrays_using_ner_mask(tokens, block_ner_mask, mask_id):
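    """Mask every position where the precomputed block NER mask is 1, recording
    the original token ids as labels and the masked indices as positions."""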
    tokenizer = get_tokenizer()
    tokens_str = join_str_list(tokenizer.tokenizer.convert_ids_to_tokens(tokens))

    masked_tokens = tokens.copy()
    masked_positions = []
    masked_labels = []


    for i in range(len(tokens)):
        if block_ner_mask[i] == 1:
            masked_positions.append(i)
            masked_labels.append(tokens[i])
            masked_tokens[i] = mask_id

    # print("-" * 100 + '\n',
    #       "TOKEN STR\n", tokens_str + '\n',
    #       "OUTPUT\n", join_str_list(tokenizer.tokenizer.convert_ids_to_tokens(masked_tokens)), flush=True)

    return masked_tokens, masked_positions, masked_labels


def create_single_tokens_and_tokentypes(_tokens, cls_id, sep_id):
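    """Wrap a list of token ids with [CLS] ... [SEP] and return all-zero tokentypes."""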
    tokens = []
    tokens.append(cls_id)
    tokens.extend(list(_tokens))
    tokens.append(sep_id)
    tokentypes = [0] * len(tokens)
    return tokens, tokentypes


def join_str_list(str_list):
    """Join a list of strings, handling spaces appropriately"""
    result = ""
    for s in str_list:
        if s.startswith("##"):
            result += s[2:]
        else:
            result += " " + s
    return result


def id_to_str_pos_map(token_ids, tokenizer):
    """Given a list of ids, return a list of integers which correspond to the starting index
    of the corresponding token in the original string (with spaces, without artifacts e.g. ##)"""
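    # For example (illustrative): token strings ['hello', '##world', 'there']
    # yield pos_map [0, 5, 11] for the detokenized string "helloworld there".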
    token_strs = tokenizer.tokenizer.convert_ids_to_tokens(token_ids)
    pos_map = [0]
    for i in range(len(token_strs) - 1):
        len_prev = len(token_strs[i])
        # do not add the length of the "##"
        if token_strs[i].startswith("##"):
            len_prev -= 2

        # add the length of the space if needed
        if token_strs[i + 1].startswith("##"):
            pos_map.append(pos_map[-1] + len_prev)
        else:
            pos_map.append(pos_map[-1] + len_prev + 1)

    # make sure total size is correct
    offset = -2 if token_strs[-1].startswith("##") else 0
    total_len = pos_map[-1] + len(token_strs[-1]) + offset
    assert total_len == len(join_str_list(token_strs)) - 1, (total_len, len(join_str_list(token_strs)))

    return pos_map


def salient_span_mask(tokens, mask_id):
    """Creates the predictions for the masked LM objective.
    Note: Tokens here are vocab ids and not text tokens."""
    tokenizer = get_tokenizer()
    tokens_str = join_str_list(tokenizer.tokenizer.convert_ids_to_tokens(tokens))

    # need to get all named entities
    entities = SPACY_NER(tokens_str).ents
    # Drop numeric/quantity entity types and spurious "CLS" matches from the special token.
    undesired_types = ['CARDINAL', 'TIME', 'PERCENT', 'MONEY', 'QUANTITY', 'ORDINAL']
    entities = [e for e in entities if e.text != "CLS" and e.label_ not in undesired_types]
    if len(entities) == 0:
        return None
    entity_idx = np.random.randint(0, len(entities))
    selected_entity = entities[entity_idx]

    token_pos_map = id_to_str_pos_map(tokens, tokenizer)
    # Find the token span covering the selected entity: advance mask_end until a
    # token's start position reaches the entity's end_char, and stop advancing
    # mask_start once a token starts past the entity's start_char.
    mask_start = mask_end = 0
    set_mask_start = False
    while mask_end < len(token_pos_map) and token_pos_map[mask_end] < selected_entity.end_char:
        if token_pos_map[mask_start] > selected_entity.start_char:
            set_mask_start = True
        if not set_mask_start:
            mask_start += 1
        mask_end += 1
    # mask_start stopped on the first token starting past the entity start, so
    # back up one token to include the token containing the entity start.
    masked_positions = list(range(mask_start - 1, mask_end))

    labels = []
    output_tokens = tokens.copy()
    for id_idx in masked_positions:
        labels.append(tokens[id_idx])
        output_tokens[id_idx] = mask_id
    # print("-" * 100 + '\n',
    #       "TOKEN STR\n", tokens_str + '\n',
    #       "SELECTED ENTITY\n", selected_entity.text + '\n',
    #       "OUTPUT\n", join_str_list(tokenizer.tokenizer.convert_ids_to_tokens(output_tokens)), flush=True)

    return output_tokens, masked_positions, labels


def get_block_samples_mapping(block_dataset, title_dataset, data_prefix, num_epochs,
                              max_num_samples, max_seq_length, seed, name):
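    """Build (on rank 0) or load the cached samples mapping for a block dataset.

    The mapping is written next to `data_prefix` as a .npy file; all ranks
    synchronize via an all-reduce before loading it.
    """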
    if not num_epochs:
        if not max_num_samples:
            raise ValueError("Need to specify either max_num_samples "
                             "or num_epochs")
        num_epochs = np.iinfo(np.int32).max - 1
    if not max_num_samples:
        max_num_samples = np.iinfo(np.int64).max - 1

    # Filename of the index mapping
    indexmap_filename = data_prefix
    indexmap_filename += '_{}_indexmap'.format(name)
    if num_epochs != (np.iinfo(np.int32).max - 1):
        indexmap_filename += '_{}ep'.format(num_epochs)
    if max_num_samples != (np.iinfo(np.int64).max - 1):
        indexmap_filename += '_{}mns'.format(max_num_samples)
    indexmap_filename += '_{}msl'.format(max_seq_length)
    indexmap_filename += '_{}s'.format(seed)
    indexmap_filename += '.npy'

    # Build the indexed mapping if not exist.
    if torch.distributed.get_rank() == 0 and \
            not os.path.isfile(indexmap_filename):
        print(' > WARNING: could not find index map file {}, building '
              'the indices on rank 0 ...'.format(indexmap_filename))

        # Make sure the types match the helpers input types.
        assert block_dataset.doc_idx.dtype == np.int64
        assert block_dataset.sizes.dtype == np.int32

        # Build samples mapping
        verbose = torch.distributed.get_rank() == 0
        start_time = time.time()
        print_rank_0(' > building samples index mapping for {} ...'.format(
            name))
        from megatron.data.dataset_utils import compile_helper
        compile_helper()
        from megatron.data import helpers
        samples_mapping = helpers.build_blocks_mapping(
            block_dataset.doc_idx,
            block_dataset.sizes,
            title_dataset.sizes,
            num_epochs,
            max_num_samples,
            max_seq_length-3,  # account for added tokens
            seed,
            verbose)
        print_rank_0(' > done building samples index mapping')
        np.save(indexmap_filename, samples_mapping, allow_pickle=True)
        print_rank_0(' > saved the index mapping in {}'.format(
            indexmap_filename))
        # Make sure all the ranks have built the mapping
        print_rank_0(' > elapsed time to build and save samples mapping '
                     '(seconds): {:4f}'.format(
            time.time() - start_time))
    # This should be a barrier but nccl barrier assumes
    # device_index=rank which is not the case for model
    # parallel case
    counts = torch.cuda.LongTensor([1])
    torch.distributed.all_reduce(counts, group=mpu.get_data_parallel_group())
    assert counts[0].item() == torch.distributed.get_world_size(
        group=mpu.get_data_parallel_group())

    # Load indexed dataset.
    print_rank_0(' > loading indexed mapping from {}'.format(
        indexmap_filename))
    start_time = time.time()
    samples_mapping = np.load(indexmap_filename, allow_pickle=True)
    print_rank_0('    loaded indexed file in {:3.3f} seconds'.format(
        time.time() - start_time))
    print_rank_0('    total number of samples: {}'.format(
        samples_mapping.shape[0]))

    return samples_mapping
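

# Illustrative usage sketch (hypothetical values; the indexed block and title
# datasets are constructed by the caller elsewhere):
#
#   samples_mapping = get_block_samples_mapping(
#       block_dataset, title_dataset, data_prefix='/path/to/corpus_prefix',
#       num_epochs=None, max_num_samples=1_000_000, max_seq_length=288,
#       seed=1234, name='train')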