Commit 1eccd462 authored by Mohammad Shoeybi
Browse files

Merge branch 'dataset-utils-lint' into 'staging'

Lint megatron/data/dataset_utils.py

See merge request ADLR/megatron-lm!42
parents 63262827 840759b8
...@@ -132,6 +132,7 @@ def truncate_segments(tokens_a, tokens_b, len_a, len_b, max_num_tokens, np_rng): ...@@ -132,6 +132,7 @@ def truncate_segments(tokens_a, tokens_b, len_a, len_b, max_num_tokens, np_rng):
tokens.pop() tokens.pop()
return True return True
def create_tokens_and_tokentypes(tokens_a, tokens_b, cls_id, sep_id): def create_tokens_and_tokentypes(tokens_a, tokens_b, cls_id, sep_id):
"""Merge segments A and B, add [CLS] and [SEP] and build tokentypes.""" """Merge segments A and B, add [CLS] and [SEP] and build tokentypes."""
...@@ -233,7 +234,7 @@ def create_masked_lm_predictions(tokens, ...@@ -233,7 +234,7 @@ def create_masked_lm_predictions(tokens,
for idx in range(len(cand_indexes)): for idx in range(len(cand_indexes)):
ngram_index = [] ngram_index = []
for n in ngrams: for n in ngrams:
ngram_index.append(cand_indexes[idx:idx+n]) ngram_index.append(cand_indexes[idx:idx + n])
ngram_indexes.append(ngram_index) ngram_indexes.append(ngram_index)
np_rng.shuffle(ngram_indexes) np_rng.shuffle(ngram_indexes)
...@@ -367,12 +368,12 @@ def pad_and_convert_to_numpy(tokens, tokentypes, masked_positions, ...@@ -367,12 +368,12 @@ def pad_and_convert_to_numpy(tokens, tokentypes, masked_positions,
assert len(masked_positions) == len(masked_labels) assert len(masked_positions) == len(masked_labels)
# Tokens and token types. # Tokens and token types.
filler = [pad_id]*padding_length filler = [pad_id] * padding_length
tokens_np = np.array(tokens + filler, dtype=np.int64) tokens_np = np.array(tokens + filler, dtype=np.int64)
tokentypes_np = np.array(tokentypes + filler, dtype=np.int64) tokentypes_np = np.array(tokentypes + filler, dtype=np.int64)
# Padding mask. # Padding mask.
padding_mask_np = np.array([1]*num_tokens + [0]*padding_length, padding_mask_np = np.array([1] * num_tokens + [0] * padding_length,
dtype=np.int64) dtype=np.int64)
# Labels and loss mask. # Labels and loss mask.
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment