Commit e2add0fd authored by Mohammad Shoeybi

resolved reproducibility issue

parent 0601702a
@@ -137,9 +137,6 @@ class AlbertDataset(Dataset):
     def __getitem__(self, idx):
-        # Note that this rng state should be python and not numpy since
-        # python randint is inclusive whereas the numpy one is exclusive.
-        rng = random.Random(self.seed + idx)
         start_index, end_index, seq_length = self.samples_mapping[idx]
         sample = []
         for index in range(start_index, end_index):
@@ -149,13 +146,16 @@ class AlbertDataset(Dataset):
             if len(s) > 1000:
                 print(self.tokenizer.convert_ids_to_tokens(s))
         '''
+        # Note that this rng state should be numpy and not python since
+        # python randint is inclusive whereas the numpy one is exclusive.
+        np_rng = np.random.RandomState(seed=(self.seed + idx))
         return build_training_sample(sample, seq_length,
                                      self.max_seq_length,  # needed for padding
                                      self.vocab_id_list,
                                      self.vocab_id_to_token_dict,
                                      self.cls_id, self.sep_id,
                                      self.mask_id, self.pad_id,
-                                     self.masked_lm_prob, rng)
+                                     self.masked_lm_prob, np_rng)
 def get_indexed_dataset_(data_prefix, data_impl, skip_warmup):
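For context, a minimal sketch (not part of the commit) of the two behaviors this change relies on: python's `randint` includes its upper bound while numpy's excludes it, and seeding a fresh `RandomState` from `seed + idx` makes each sample's randomness deterministic:

```python
import random

import numpy as np

# Python's randint includes the upper bound; numpy's excludes it.
py_rng = random.Random(1234)
np_rng = np.random.RandomState(seed=1234)
assert py_rng.randint(0, 9) in range(10)   # draws from {0, ..., 9}
assert np_rng.randint(0, 10) in range(10)  # also draws from {0, ..., 9}

# Re-seeding per sample index is what makes __getitem__ reproducible:
# the same idx always yields the same sequence of random draws.
def make_rng(seed, idx):
    return np.random.RandomState(seed=(seed + idx))

a = make_rng(1234, 7).randint(0, 100, size=5)
b = make_rng(1234, 7).randint(0, 100, size=5)
assert (a == b).all()
```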
@@ -9,7 +9,7 @@ def build_training_sample(sample,
                           target_seq_length, max_seq_length,
                           vocab_id_list, vocab_id_to_token_dict,
                           cls_id, sep_id, mask_id, pad_id,
-                          masked_lm_prob, rng):
+                          masked_lm_prob, np_rng):
     """Build training sample.

     Arguments:
@@ -24,8 +24,8 @@ def build_training_sample(sample,
         mask_id: Mask token id.
         pad_id: Padding token id.
         masked_lm_prob: Probability to mask tokens.
-        rng: Random number generator. Note that this rng state should be
-            python and not numpy since python randint is inclusive for
+        np_rng: Random number generator. Note that this rng state should be
+            numpy and not python since python randint is inclusive for
             the upper bound whereas the numpy one is exclusive.
     """
@@ -34,12 +34,12 @@ def build_training_sample(sample,
     assert target_seq_length <= max_seq_length

     # Divide sample into two segments (A and B).
-    tokens_a, tokens_b, is_next_random = get_a_and_b_segments(sample, rng)
+    tokens_a, tokens_b, is_next_random = get_a_and_b_segments(sample, np_rng)

     # Truncate to `target_sequence_length`.
     max_num_tokens = target_seq_length
     truncated = truncate_segments(tokens_a, tokens_b, len(tokens_a),
-                                  len(tokens_b), max_num_tokens, rng)
+                                  len(tokens_b), max_num_tokens, np_rng)

     # Build tokens and tokentypes.
     tokens, tokentypes = create_tokens_and_tokentypes(tokens_a, tokens_b,
@@ -49,7 +49,7 @@ def build_training_sample(sample,
     max_predictions_per_seq = masked_lm_prob * max_num_tokens
     (tokens, masked_positions, masked_labels, _) = create_masked_lm_predictions(
         tokens, vocab_id_list, vocab_id_to_token_dict, masked_lm_prob,
-        cls_id, sep_id, mask_id, max_predictions_per_seq, rng)
+        cls_id, sep_id, mask_id, max_predictions_per_seq, np_rng)

     # Padding.
     tokens_np, tokentypes_np, labels_np, padding_mask_np, loss_mask_np \
@@ -67,7 +67,7 @@ def build_training_sample(sample,
     return train_sample


-def get_a_and_b_segments(sample, rng):
+def get_a_and_b_segments(sample, np_rng):
     """Divide sample into a and b segments."""

     # Number of sentences in the sample.
@@ -79,8 +79,8 @@ def get_a_and_b_segments(sample, rng):
     # `a_end` is how many sentences go into the `A`.
     a_end = 1
     if n_sentences >= 3:
-        # Note that randint in python is inclusive.
-        a_end = rng.randint(1, n_sentences - 1)
+        # Note that randint in numpy is exclusive.
+        a_end = np_rng.randint(1, n_sentences)
     tokens_a = []
     for j in range(a_end):
         tokens_a.extend(sample[j])
@@ -92,14 +92,14 @@ def get_a_and_b_segments(sample, rng):
     # Random next:
     is_next_random = False
-    if rng.random() < 0.5:
+    if np_rng.random() < 0.5:
         is_next_random = True
         tokens_a, tokens_b = tokens_b, tokens_a

     return tokens_a, tokens_b, is_next_random


-def truncate_segments(tokens_a, tokens_b, len_a, len_b, max_num_tokens, rng):
+def truncate_segments(tokens_a, tokens_b, len_a, len_b, max_num_tokens, np_rng):
     """Truncates a pair of sequences to a maximum sequence length."""
     #print(len_a, len_b, max_num_tokens)
     assert len_a > 0
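Worth noting that the new `a_end` draw covers exactly the same range as the old one. A quick illustrative check (hypothetical `n_sentences`, not from the repo):

```python
import numpy as np

# numpy's exclusive upper bound means randint(1, n_sentences) draws from
# {1, ..., n_sentences - 1}, matching the old inclusive python
# randint(1, n_sentences - 1), so segment B always keeps >= 1 sentence.
np_rng = np.random.RandomState(seed=0)
n_sentences = 5
draws = {int(np_rng.randint(1, n_sentences)) for _ in range(1000)}
assert draws == {1, 2, 3, 4}
```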
@@ -113,7 +113,7 @@ def truncate_segments(tokens_a, tokens_b, len_a, len_b, max_num_tokens, rng):
         else:
             len_b -= 1
             tokens = tokens_b
-        if rng.random() < 0.5:
+        if np_rng.random() < 0.5:
             del tokens[0]
         else:
             tokens.pop()
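For reference, a condensed sketch of what this truncation loop does (simplified; the repo version also tracks `len_a`/`len_b` separately and returns whether anything was trimmed):

```python
import numpy as np

def truncate_pair(tokens_a, tokens_b, max_num_tokens, np_rng):
    """Trim the longer segment one token at a time, from a randomly
    chosen end, until the pair fits in max_num_tokens."""
    while len(tokens_a) + len(tokens_b) > max_num_tokens:
        tokens = tokens_a if len(tokens_a) > len(tokens_b) else tokens_b
        if np_rng.random() < 0.5:
            del tokens[0]   # drop from the front
        else:
            tokens.pop()    # drop from the back
    return tokens_a, tokens_b

np_rng = np.random.RandomState(seed=0)
a, b = truncate_pair(list(range(10)), list(range(8)), 12, np_rng)
assert len(a) + len(b) == 12
```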
@@ -163,7 +163,7 @@ def create_masked_lm_predictions(tokens,
                                  masked_lm_prob,
                                  cls_id, sep_id, mask_id,
                                  max_predictions_per_seq,
-                                 rng,
+                                 np_rng,
                                  max_ngrams=3,
                                  do_whole_word_mask=True,
                                  favor_longer_ngram=False,
@@ -223,7 +223,7 @@ def create_masked_lm_predictions(tokens,
             ngram_index.append(cand_indexes[idx:idx+n])
         ngram_indexes.append(ngram_index)

-    rng.shuffle(ngram_indexes)
+    np_rng.shuffle(ngram_indexes)

     masked_lms = []
     covered_indexes = set()
@@ -239,9 +239,9 @@ def create_masked_lm_predictions(tokens,
             if index in covered_indexes:
                 continue

-        n = np.random.choice(ngrams[:len(cand_index_set)],
-                             p=pvals[:len(cand_index_set)] /
-                             pvals[:len(cand_index_set)].sum(keepdims=True))
+        n = np_rng.choice(ngrams[:len(cand_index_set)],
+                          p=pvals[:len(cand_index_set)] /
+                          pvals[:len(cand_index_set)].sum(keepdims=True))
         index_set = sum(cand_index_set[n - 1], [])
         n -= 1
         # Note(mingdachen):
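This swaps the module-level `np.random.choice` for the per-sample generator, which matters for reproducibility: the global numpy state is not tied to the sample index. A small sketch of the truncate-and-renormalize pattern, assuming `pvals` holds the 1/n n-gram weights (its definition is outside this diff):

```python
import numpy as np

np_rng = np.random.RandomState(seed=0)
ngrams = np.arange(1, 4)           # candidate n-gram lengths 1..3
pvals = 1. / ngrams                # favor shorter n-grams
pvals /= pvals.sum(keepdims=True)  # normalize to a distribution

# When fewer candidates are available, slice and renormalize so the
# probabilities passed to choice() still sum to 1.
k = 2
n = np_rng.choice(ngrams[:k], p=pvals[:k] / pvals[:k].sum(keepdims=True))
assert n in (1, 2)
```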
@@ -268,22 +268,22 @@ def create_masked_lm_predictions(tokens,

             masked_token = None
             # 80% of the time, replace with [MASK]
-            if rng.random() < 0.8:
+            if np_rng.random() < 0.8:
                 masked_token = mask_id
             else:
                 # 10% of the time, keep original
-                if rng.random() < 0.5:
+                if np_rng.random() < 0.5:
                     masked_token = tokens[index]
                 # 10% of the time, replace with random word
                 else:
-                    masked_token = vocab_id_list[rng.randint(0, len(vocab_id_list) - 1)]
+                    masked_token = vocab_id_list[np_rng.randint(0, len(vocab_id_list))]

             output_tokens[index] = masked_token
             masked_lms.append(MaskedLmInstance(index=index, label=tokens[index]))

     assert len(masked_lms) <= num_to_predict

-    rng.shuffle(ngram_indexes)
+    np_rng.shuffle(ngram_indexes)

     select_indexes = set()
     if do_permutation:
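The replacement branch above is the standard BERT 80/10/10 masking scheme; note the random-word case no longer needs the `- 1`, since numpy's upper bound is exclusive. A self-contained sketch (hypothetical helper, not from the repo):

```python
import numpy as np

def choose_masked_token(np_rng, original_id, mask_id, vocab_id_list):
    """BERT-style replacement: 80% [MASK], 10% keep, 10% random token."""
    if np_rng.random() < 0.8:
        return mask_id
    if np_rng.random() < 0.5:
        return original_id  # keep the original token
    # Exclusive upper bound: len(vocab_id_list) already maps to the
    # last valid index, so no off-by-one adjustment is needed.
    return vocab_id_list[np_rng.randint(0, len(vocab_id_list))]

np_rng = np.random.RandomState(seed=1234)
token = choose_masked_token(np_rng, original_id=42, mask_id=103,
                            vocab_id_list=list(range(1000)))
assert token == 103 or 0 <= token < 1000
```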
@@ -327,7 +327,7 @@ def create_masked_lm_predictions(tokens,
         select_indexes = sorted(select_indexes)
         permute_indexes = list(select_indexes)
-        rng.shuffle(permute_indexes)
+        np_rng.shuffle(permute_indexes)
         orig_token = list(output_tokens)

         for src_i, tgt_i in zip(select_indexes, permute_indexes):