Commit 1eccd462 authored by Mohammad Shoeybi's avatar Mohammad Shoeybi
Browse files

Merge branch 'dataset-utils-lint' into 'staging'

Lint megatron/data/dataset_utils.py

See merge request ADLR/megatron-lm!42
parents 63262827 840759b8
...@@ -132,6 +132,7 @@ def truncate_segments(tokens_a, tokens_b, len_a, len_b, max_num_tokens, np_rng): ...@@ -132,6 +132,7 @@ def truncate_segments(tokens_a, tokens_b, len_a, len_b, max_num_tokens, np_rng):
tokens.pop() tokens.pop()
return True return True
def create_tokens_and_tokentypes(tokens_a, tokens_b, cls_id, sep_id): def create_tokens_and_tokentypes(tokens_a, tokens_b, cls_id, sep_id):
"""Merge segments A and B, add [CLS] and [SEP] and build tokentypes.""" """Merge segments A and B, add [CLS] and [SEP] and build tokentypes."""
...@@ -163,12 +164,12 @@ MaskedLmInstance = collections.namedtuple("MaskedLmInstance", ...@@ -163,12 +164,12 @@ MaskedLmInstance = collections.namedtuple("MaskedLmInstance",
def is_start_piece(piece):
    """Return True if *piece* begins a new word under BERT WordPiece rules.

    WordPiece splits a word into sub-tokens and prefixes every
    non-initial sub-token with '##'; whenever that marker is present the
    piece belongs to the preceding word, so it is not a start piece.
    """
    return piece[:2] != "##"
def create_masked_lm_predictions(tokens,
                                 vocab_id_list, vocab_id_to_token_dict,
                                 masked_lm_prob,
                                 cls_id, sep_id, mask_id,
                                 max_predictions_per_seq,
                                 np_rng,
                                 max_ngrams=3,
                                 do_whole_word_mask=True,
                                 favor_longer_ngram=False,
                                 do_permutation=False):
    """Creates the predictions for the masked LM objective.

    Note: Tokens here are vocab ids and not text tokens.

    Arguments:
        tokens: list of vocab ids for one (already [CLS]/[SEP]-framed) sample.
        vocab_id_list: list of all vocab ids, used to draw random replacements.
        vocab_id_to_token_dict: maps a vocab id to its text token, needed to
            detect '##' continuation pieces for whole-word masking.
        masked_lm_prob: fraction of tokens to mask; 0 disables masking.
        cls_id, sep_id, mask_id: special-token vocab ids.
        max_predictions_per_seq: hard cap on the number of masked positions.
        np_rng: numpy RandomState; the ONLY source of randomness here so that
            results are reproducible given a seed.
        max_ngrams: longest span (in whole words) that may be masked at once.
        do_whole_word_mask: mask all pieces of a word together.
        favor_longer_ngram: reverse the ngram-length distribution.
        do_permutation: additionally permute a second set of spans (XLNet-style).

    Returns:
        (output_tokens, masked_lm_positions, masked_lm_labels, token_boundary)
    """
    cand_indexes = []
    # Note(mingdachen): We create a list for recording if the piece is
    # the starting piece of current token, where 1 means true, so that
    # on-the-fly whole word masking is possible.
    token_boundary = [0] * len(tokens)

    for (i, token) in enumerate(tokens):
        if token == cls_id or token == sep_id:
            token_boundary[i] = 1
            continue
        # Whole Word Masking means that if we mask all of the wordpieces
        # corresponding to an original word.
        #
        # Note that Whole Word Masking does *not* change the training code
        # at all -- we still predict each WordPiece independently, softmaxed
        # over the entire vocabulary.
        if (do_whole_word_mask and len(cand_indexes) >= 1 and
                not is_start_piece(vocab_id_to_token_dict[token])):
            cand_indexes[-1].append(i)
        else:
            cand_indexes.append([i])
            if is_start_piece(vocab_id_to_token_dict[token]):
                token_boundary[i] = 1

    output_tokens = list(tokens)

    masked_lm_positions = []
    masked_lm_labels = []

    if masked_lm_prob == 0:
        return (output_tokens, masked_lm_positions,
                masked_lm_labels, token_boundary)

    num_to_predict = min(max_predictions_per_seq,
                         max(1, int(round(len(tokens) * masked_lm_prob))))

    # Note(mingdachen):
    # By default, we set the probilities to favor shorter ngram sequences.
    ngrams = np.arange(1, max_ngrams + 1, dtype=np.int64)
    pvals = 1. / np.arange(1, max_ngrams + 1)
    pvals /= pvals.sum(keepdims=True)

    if favor_longer_ngram:
        pvals = pvals[::-1]

    # For each candidate word, pre-build the list of 1..max_ngrams spans
    # starting at that word.
    ngram_indexes = []
    for idx in range(len(cand_indexes)):
        ngram_index = []
        for n in ngrams:
            ngram_index.append(cand_indexes[idx:idx + n])
        ngram_indexes.append(ngram_index)

    np_rng.shuffle(ngram_indexes)

    masked_lms = []
    covered_indexes = set()
    for cand_index_set in ngram_indexes:
        if len(masked_lms) >= num_to_predict:
            break
        if not cand_index_set:
            continue
        # Note(mingdachen):
        # Skip current piece if they are covered in lm masking or previous ngrams.
        for index_set in cand_index_set[0]:
            for index in index_set:
                if index in covered_indexes:
                    continue

        # Sample an ngram length (re-normalizing pvals over the lengths
        # actually available at this position).
        n = np_rng.choice(ngrams[:len(cand_index_set)],
                          p=pvals[:len(cand_index_set)] /
                          pvals[:len(cand_index_set)].sum(keepdims=True))
        index_set = sum(cand_index_set[n - 1], [])
        n -= 1
        # Note(mingdachen):
        # Repeatedly looking for a candidate that does not exceed the
        # maximum number of predictions by trying shorter ngrams.
        while len(masked_lms) + len(index_set) > num_to_predict:
            if n == 0:
                break
            index_set = sum(cand_index_set[n - 1], [])
            n -= 1
        # If adding a whole-word mask would exceed the maximum number of
        # predictions, then just skip this candidate.
        if len(masked_lms) + len(index_set) > num_to_predict:
            continue
        is_any_index_covered = False
        for index in index_set:
            if index in covered_indexes:
                is_any_index_covered = True
                break
        if is_any_index_covered:
            continue
        for index in index_set:
            covered_indexes.add(index)

            masked_token = None
            # 80% of the time, replace with [MASK]
            if np_rng.random() < 0.8:
                masked_token = mask_id
            else:
                # 10% of the time, keep original
                if np_rng.random() < 0.5:
                    masked_token = tokens[index]
                # 10% of the time, replace with random word
                else:
                    masked_token = vocab_id_list[np_rng.randint(0, len(vocab_id_list))]

            output_tokens[index] = masked_token
            masked_lms.append(MaskedLmInstance(index=index, label=tokens[index]))
    assert len(masked_lms) <= num_to_predict

    np_rng.shuffle(ngram_indexes)

    select_indexes = set()
    if do_permutation:
        for cand_index_set in ngram_indexes:
            if len(select_indexes) >= num_to_predict:
                break
            if not cand_index_set:
                continue
            # Note(mingdachen):
            # Skip current piece if they are covered in lm masking or previous ngrams.
            for index_set in cand_index_set[0]:
                for index in index_set:
                    if index in covered_indexes or index in select_indexes:
                        continue

            # Bug fix: this used the global `np.random.choice` while the
            # masking loop above uses the caller-supplied `np_rng`.  Using
            # the global RNG ignores the seed carried by `np_rng` and makes
            # permutation non-reproducible; route it through `np_rng` too.
            n = np_rng.choice(ngrams[:len(cand_index_set)],
                              p=pvals[:len(cand_index_set)] /
                              pvals[:len(cand_index_set)].sum(keepdims=True))
            index_set = sum(cand_index_set[n - 1], [])
            n -= 1

            while len(select_indexes) + len(index_set) > num_to_predict:
                if n == 0:
                    break
                index_set = sum(cand_index_set[n - 1], [])
                n -= 1
            # If adding a whole-word mask would exceed the maximum number of
            # predictions, then just skip this candidate.
            if len(select_indexes) + len(index_set) > num_to_predict:
                continue
            is_any_index_covered = False
            for index in index_set:
                if index in covered_indexes or index in select_indexes:
                    is_any_index_covered = True
                    break
            if is_any_index_covered:
                continue
            for index in index_set:
                select_indexes.add(index)
        assert len(select_indexes) <= num_to_predict

        # Apply a random permutation over the selected positions; the label
        # for each position is its pre-permutation token.
        select_indexes = sorted(select_indexes)
        permute_indexes = list(select_indexes)
        np_rng.shuffle(permute_indexes)
        orig_token = list(output_tokens)

        for src_i, tgt_i in zip(select_indexes, permute_indexes):
            output_tokens[src_i] = orig_token[tgt_i]
            masked_lms.append(MaskedLmInstance(index=src_i, label=orig_token[src_i]))

    masked_lms = sorted(masked_lms, key=lambda x: x.index)

    for p in masked_lms:
        masked_lm_positions.append(p.index)
        masked_lm_labels.append(p.label)
    return (output_tokens, masked_lm_positions, masked_lm_labels, token_boundary)
def pad_and_convert_to_numpy(tokens, tokentypes, masked_positions, def pad_and_convert_to_numpy(tokens, tokentypes, masked_positions,
...@@ -367,12 +368,12 @@ def pad_and_convert_to_numpy(tokens, tokentypes, masked_positions, ...@@ -367,12 +368,12 @@ def pad_and_convert_to_numpy(tokens, tokentypes, masked_positions,
assert len(masked_positions) == len(masked_labels) assert len(masked_positions) == len(masked_labels)
# Tokens and token types. # Tokens and token types.
filler = [pad_id]*padding_length filler = [pad_id] * padding_length
tokens_np = np.array(tokens + filler, dtype=np.int64) tokens_np = np.array(tokens + filler, dtype=np.int64)
tokentypes_np = np.array(tokentypes + filler, dtype=np.int64) tokentypes_np = np.array(tokentypes + filler, dtype=np.int64)
# Padding mask. # Padding mask.
padding_mask_np = np.array([1]*num_tokens + [0]*padding_length, padding_mask_np = np.array([1] * num_tokens + [0] * padding_length,
dtype=np.int64) dtype=np.int64)
# Lables and loss mask. # Lables and loss mask.
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment