# coding=utf-8
# Copyright 2018 The Google AI Language Team Authors, and NVIDIA.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import collections

import numpy as np


def compile_helper():
    """Compile the C++ dataset helpers at runtime.

    Runs ``make -C <directory of this file>`` and terminates the process
    with exit status 1 when the build fails. Make sure this is invoked on a
    single process.
    """
    import os
    import subprocess
    import sys

    path = os.path.abspath(os.path.dirname(__file__))
    ret = subprocess.run(['make', '-C', path])
    if ret.returncode != 0:
        print("Making C++ dataset helpers module failed, exiting.")
        sys.exit(1)


def build_training_sample(sample,
                          target_seq_length, max_seq_length,
                          vocab_id_list, vocab_id_to_token_dict,
                          cls_id, sep_id, mask_id, pad_id,
                          masked_lm_prob, np_rng):
    """Build a training sample.

    Arguments:
        sample: A list of sentences in which each sentence is a list of
            token ids.
        target_seq_length: Desired sequence length. Segments A and B are
            truncated so that together they hold at most this many tokens;
            [CLS] and the two [SEP] tokens are added afterwards.
        max_seq_length: Maximum length of the sequence. All values are
            padded to this length.
        vocab_id_list: List of vocabulary ids. Used to pick a random id.
        vocab_id_to_token_dict: A dictionary from vocab ids to text tokens.
        cls_id: Start of example id.
        sep_id: Separator id.
        mask_id: Mask token id.
        pad_id: Padding token id.
        masked_lm_prob: Probability to mask tokens.
        np_rng: Random number generator. Note that this rng state should be
            numpy and not python since python randint is inclusive for
            the upper bound whereas the numpy one is exclusive.

    Returns:
        A dictionary with the padded numpy arrays 'text', 'types', 'labels',
        'loss_mask' and 'padding_mask', plus the integer flags 'is_random'
        (segments were swapped) and 'truncated' (tokens were dropped).
    """
    # We assume that we have at least two sentences in the sample.
    assert len(sample) > 1
    assert target_seq_length <= max_seq_length

    # Divide sample into two segments (A and B).
    tokens_a, tokens_b, is_next_random = get_a_and_b_segments(sample, np_rng)

    # Truncate to `target_seq_length` (the lists are shortened in place).
    max_num_tokens = target_seq_length
    truncated = truncate_segments(tokens_a, tokens_b, len(tokens_a),
                                  len(tokens_b), max_num_tokens, np_rng)

    # Build tokens and tokentypes.
    tokens, tokentypes = create_tokens_and_tokentypes(tokens_a, tokens_b,
                                                      cls_id, sep_id)

    # Masking. The bound is a float; it is only ever compared against
    # integer counts, so it acts like its floor.
    max_predictions_per_seq = masked_lm_prob * max_num_tokens
    (tokens, masked_positions, masked_labels, _) = \
        create_masked_lm_predictions(
            tokens, vocab_id_list, vocab_id_to_token_dict, masked_lm_prob,
            cls_id, sep_id, mask_id, max_predictions_per_seq, np_rng)

    # Padding.
    tokens_np, tokentypes_np, labels_np, padding_mask_np, loss_mask_np \
        = pad_and_convert_to_numpy(tokens, tokentypes, masked_positions,
                                   masked_labels, pad_id, max_seq_length)

    train_sample = {
        'text': tokens_np,
        'types': tokentypes_np,
        'labels': labels_np,
        'is_random': int(is_next_random),
        'loss_mask': loss_mask_np,
        'padding_mask': padding_mask_np,
        'truncated': int(truncated)}
    return train_sample


def get_a_and_b_segments(sample, np_rng):
    """Divide sample into a and b segments.

    The first `a_end` sentences are concatenated into segment A and the
    remaining ones into segment B. With probability 0.5 the two segments are
    swapped, in which case `is_next_random` is True.

    Returns:
        (tokens_a, tokens_b, is_next_random)
    """
    # Number of sentences in the sample.
    n_sentences = len(sample)
    # Make sure we always have two sentences.
    assert n_sentences > 1, 'make sure each sample has at least two sentences.'

    # First part:
    # `a_end` is how many sentences go into the `A`.
    a_end = 1
    if n_sentences >= 3:
        # Note that randint in numpy is exclusive of the upper bound, so
        # segment B always receives at least one sentence.
        a_end = np_rng.randint(1, n_sentences)
    tokens_a = [token for sentence in sample[:a_end] for token in sentence]

    # Second part:
    tokens_b = [token for sentence in sample[a_end:] for token in sentence]

    # Random next:
    is_next_random = False
    if np_rng.random() < 0.5:
        is_next_random = True
        tokens_a, tokens_b = tokens_b, tokens_a

    return tokens_a, tokens_b, is_next_random


def truncate_segments(tokens_a, tokens_b, len_a, len_b, max_num_tokens,
                      np_rng):
    """Truncate a pair of sequences to a maximum combined length.

    `tokens_a` and `tokens_b` are modified in place: one token at a time is
    removed from the currently longer segment (from B on ties), randomly
    from its front or its back, until both fit in `max_num_tokens`.

    Returns:
        True if any token was removed, False otherwise.
    """
    assert len_a > 0
    assert len_b > 0
    if len_a + len_b <= max_num_tokens:
        return False
    while len_a + len_b > max_num_tokens:
        if len_a > len_b:
            len_a -= 1
            tokens = tokens_a
        else:
            len_b -= 1
            tokens = tokens_b
        if np_rng.random() < 0.5:
            del tokens[0]
        else:
            tokens.pop()
    return True


def create_tokens_and_tokentypes(tokens_a, tokens_b, cls_id, sep_id):
    """Merge segments A and B, add [CLS] and [SEP] and build tokentypes.

    The layout is ``[CLS] A [SEP] B [SEP]``. [CLS], segment A and the first
    [SEP] get token type 0; segment B and the final [SEP] get token type 1.

    Returns:
        (tokens, tokentypes), two lists of equal length.
    """
    tokens = [cls_id, *tokens_a, sep_id, *tokens_b, sep_id]
    tokentypes = [0] * (len(tokens_a) + 2) + [1] * (len(tokens_b) + 1)
    return tokens, tokentypes


MaskedLmInstance = collections.namedtuple("MaskedLmInstance",
                                          ["index", "label"])


def is_start_piece(piece):
    """Check if the current word piece is the starting piece (BERT)."""
    # When a word has been split into WordPieces, the first token does not
    # have any marker and any subsequent tokens are prefixed with ##. So
    # whenever we see the ## token, we append it to the previous set of
    # word indexes.
    return not piece.startswith("##")


def _draw_ngram_positions(cand_index_set, ngrams, pvals, num_taken,
                          num_to_predict, np_rng):
    """Draw an n-gram for one candidate and return its token positions.

    `cand_index_set[k]` holds the (k + 1)-gram starting at the candidate, as
    a list of words where each word is a list of token positions. The n-gram
    length is sampled from `ngrams` with probabilities `pvals`; while the
    drawn span does not fit into the remaining budget
    (`num_to_predict - num_taken`) shorter n-grams are tried.

    Returns:
        The flat list of token positions, or None when not even the unigram
        fits into the remaining budget.
    """
    num_options = len(cand_index_set)
    probs = pvals[:num_options]
    n = np_rng.choice(ngrams[:num_options],
                      p=probs / probs.sum(keepdims=True))
    while True:
        positions = [i for word in cand_index_set[n - 1] for i in word]
        n -= 1
        if num_taken + len(positions) <= num_to_predict:
            return positions
        if n == 0:
            # Even the shortest n-gram exceeds the budget: skip candidate.
            return None


def create_masked_lm_predictions(tokens,
                                 vocab_id_list, vocab_id_to_token_dict,
                                 masked_lm_prob,
                                 cls_id, sep_id, mask_id,
                                 max_predictions_per_seq,
                                 np_rng,
                                 max_ngrams=3,
                                 do_whole_word_mask=True,
                                 favor_longer_ngram=False,
                                 do_permutation=False):
    """Create the predictions for the masked LM objective.

    Note: Tokens here are vocab ids and not text tokens.

    Arguments:
        tokens: List of vocab ids, including [CLS] and [SEP].
        vocab_id_list: List of vocabulary ids. Used to pick a random id.
        vocab_id_to_token_dict: A dictionary from vocab ids to text tokens;
            used to detect "##" continuation pieces.
        masked_lm_prob: Fraction of tokens to predict. When 0, nothing is
            masked.
        cls_id, sep_id: Special ids that are never masked.
        mask_id: Id written in place of a masked token.
        max_predictions_per_seq: Upper bound on the number of predictions.
        np_rng: Numpy random state; every random decision is drawn from it.
        max_ngrams: Longest n-gram (in words) that can be masked at once.
        do_whole_word_mask: Treat all word pieces of a word as one unit.
        favor_longer_ngram: Reverse the n-gram length distribution so longer
            n-grams are more likely.
        do_permutation: Additionally select up to `num_to_predict` uncovered
            positions and shuffle their tokens among themselves.

    Returns:
        (output_tokens, masked_lm_positions, masked_lm_labels,
        token_boundary) where positions are sorted ascending, labels hold
        the token found at each position before it was replaced, and
        `token_boundary[i]` is 1 when token i is [CLS], [SEP] or the first
        piece of a word.
    """
    cand_indexes = []
    # Note(mingdachen): We create a list for recording if the piece is
    # the starting piece of current token, where 1 means true, so that
    # on-the-fly whole word masking is possible.
    token_boundary = [0] * len(tokens)

    for (i, token) in enumerate(tokens):
        if token == cls_id or token == sep_id:
            token_boundary[i] = 1
            continue
        starts_word = is_start_piece(vocab_id_to_token_dict[token])
        # Whole Word Masking means that we mask all of the wordpieces
        # corresponding to an original word.
        #
        # Note that Whole Word Masking does *not* change the training code
        # at all -- we still predict each WordPiece independently, softmaxed
        # over the entire vocabulary.
        if do_whole_word_mask and len(cand_indexes) >= 1 and not starts_word:
            cand_indexes[-1].append(i)
        else:
            cand_indexes.append([i])
            if starts_word:
                token_boundary[i] = 1

    output_tokens = list(tokens)

    masked_lm_positions = []
    masked_lm_labels = []

    if masked_lm_prob == 0:
        return (output_tokens, masked_lm_positions,
                masked_lm_labels, token_boundary)

    num_to_predict = min(max_predictions_per_seq,
                         max(1, int(round(len(tokens) * masked_lm_prob))))

    # Note(mingdachen):
    # By default, we set the probabilities to favor shorter ngram sequences.
    ngrams = np.arange(1, max_ngrams + 1, dtype=np.int64)
    pvals = 1. / np.arange(1, max_ngrams + 1)
    pvals /= pvals.sum(keepdims=True)

    if favor_longer_ngram:
        pvals = pvals[::-1]

    # For every candidate word, the 1-gram .. max_ngrams-gram starting there.
    # Near the end of the sequence the longer slices are simply shorter.
    ngram_indexes = [
        [cand_indexes[idx:idx + n] for n in ngrams]
        for idx in range(len(cand_indexes))
    ]

    np_rng.shuffle(ngram_indexes)

    masked_lms = []
    covered_indexes = set()
    for cand_index_set in ngram_indexes:
        if len(masked_lms) >= num_to_predict:
            break
        if not cand_index_set:
            continue
        # Note(mingdachen):
        # Repeatedly looking for a candidate that does not exceed the
        # maximum number of predictions by trying shorter ngrams. If even
        # the shortest one would exceed it, just skip this candidate.
        index_set = _draw_ngram_positions(cand_index_set, ngrams, pvals,
                                          len(masked_lms), num_to_predict,
                                          np_rng)
        if index_set is None:
            continue
        # Skip the n-gram if any piece is already covered by a previous one.
        # The check happens after the draw so the random stream does not
        # depend on what has been covered so far.
        if any(index in covered_indexes for index in index_set):
            continue
        for index in index_set:
            covered_indexes.add(index)

            # 80% of the time, replace with [MASK]
            if np_rng.random() < 0.8:
                masked_token = mask_id
            # 10% of the time, keep original
            elif np_rng.random() < 0.5:
                masked_token = tokens[index]
            # 10% of the time, replace with random word
            else:
                masked_token = vocab_id_list[
                    np_rng.randint(0, len(vocab_id_list))]

            output_tokens[index] = masked_token
            masked_lms.append(MaskedLmInstance(index=index,
                                               label=tokens[index]))
    assert len(masked_lms) <= num_to_predict

    # Reshuffled unconditionally so the state of `np_rng` after this call
    # does not depend on `do_permutation` being a no-op for this phase.
    np_rng.shuffle(ngram_indexes)

    select_indexes = set()
    if do_permutation:
        for cand_index_set in ngram_indexes:
            if len(select_indexes) >= num_to_predict:
                break
            if not cand_index_set:
                continue
            # Draw from `np_rng` (not the global numpy state) so that the
            # permutation is reproducible from the caller's seed.
            index_set = _draw_ngram_positions(cand_index_set, ngrams, pvals,
                                              len(select_indexes),
                                              num_to_predict, np_rng)
            if index_set is None:
                continue
            # Skip the n-gram if any piece was masked or already selected.
            if any(index in covered_indexes or index in select_indexes
                   for index in index_set):
                continue
            select_indexes.update(index_set)
        assert len(select_indexes) <= num_to_predict

        select_indexes = sorted(select_indexes)
        permute_indexes = list(select_indexes)
        np_rng.shuffle(permute_indexes)
        orig_token = list(output_tokens)

        for src_i, tgt_i in zip(select_indexes, permute_indexes):
            output_tokens[src_i] = orig_token[tgt_i]
            masked_lms.append(MaskedLmInstance(index=src_i,
                                               label=orig_token[src_i]))

    masked_lms = sorted(masked_lms, key=lambda x: x.index)

    for p in masked_lms:
        masked_lm_positions.append(p.index)
        masked_lm_labels.append(p.label)

    return (output_tokens, masked_lm_positions,
            masked_lm_labels, token_boundary)


def pad_and_convert_to_numpy(tokens, tokentypes, masked_positions,
                             masked_labels, pad_id, max_seq_length):
    """Pad sequences and convert them to numpy.

    Arguments:
        tokens: List of token ids; at most `max_seq_length` long.
        tokentypes: List of token types, same length as `tokens`.
        masked_positions: Positions in `tokens` that are to be predicted.
        masked_labels: Label for each entry of `masked_positions`.
        pad_id: Value used to pad both tokens and token types.
        max_seq_length: Length of every returned array.

    Returns:
        (tokens_np, tokentypes_np, labels_np, padding_mask_np, loss_mask_np),
        all int64 arrays of length `max_seq_length`. Labels are -1 and the
        loss mask is 0 everywhere except at the masked positions; the
        padding mask is 1 on real tokens and 0 on padding.
    """
    # Some checks.
    num_tokens = len(tokens)
    padding_length = max_seq_length - num_tokens
    assert padding_length >= 0
    assert len(tokentypes) == num_tokens
    assert len(masked_positions) == len(masked_labels)

    # Tokens and token types.
    filler = [pad_id] * padding_length
    tokens_np = np.array(tokens + filler, dtype=np.int64)
    tokentypes_np = np.array(tokentypes + filler, dtype=np.int64)

    # Padding mask.
    padding_mask_np = np.array([1] * num_tokens + [0] * padding_length,
                               dtype=np.int64)

    # Labels and loss mask.
    labels = [-1] * max_seq_length
    loss_mask = [0] * max_seq_length
    for position, label in zip(masked_positions, masked_labels):
        assert position < num_tokens
        labels[position] = label
        loss_mask[position] = 1
    labels_np = np.array(labels, dtype=np.int64)
    loss_mask_np = np.array(loss_mask, dtype=np.int64)

    return tokens_np, tokentypes_np, labels_np, padding_mask_np, loss_mask_np