Commit 8b032e91 authored by Hongkun Yu, committed by A. Unique TensorFlower

Internal change

PiperOrigin-RevId: 396717391
parent e453835a
@@ -43,10 +43,11 @@ def _create_fake_dataset(output_path, seq_length, num_masked_tokens,
     f = tf.train.Feature(float_list=tf.train.FloatList(value=list(values)))
     return f
 
+  rng = np.random.default_rng(37)
   for _ in range(num_examples):
     features = {}
     padding = np.zeros(shape=(max_seq_length - seq_length), dtype=np.int32)
-    input_ids = np.random.randint(low=1, high=100, size=(seq_length))
+    input_ids = rng.integers(low=1, high=100, size=(seq_length))
     features['input_ids'] = create_int_feature(
         np.concatenate((input_ids, padding)))
     features['input_mask'] = create_int_feature(
@@ -56,9 +57,9 @@ def _create_fake_dataset(output_path, seq_length, num_masked_tokens,
     features['position_ids'] = create_int_feature(
         np.concatenate((np.ones_like(input_ids), padding)))
     features['masked_lm_positions'] = create_int_feature(
-        np.random.randint(60, size=(num_masked_tokens), dtype=np.int64))
+        rng.integers(60, size=(num_masked_tokens), dtype=np.int64))
     features['masked_lm_ids'] = create_int_feature(
-        np.random.randint(100, size=(num_masked_tokens), dtype=np.int64))
+        rng.integers(100, size=(num_masked_tokens), dtype=np.int64))
     features['masked_lm_weights'] = create_float_feature(
         np.ones((num_masked_tokens,), dtype=np.float32))
     features['next_sentence_labels'] = create_int_feature(np.array([0]))
@@ -156,6 +157,7 @@ class PretrainDynamicDataLoaderTest(tf.test.TestCase, parameterized.TestCase):
       self.assertEqual(dynamic_metrics[key], static_metrics[key])
 
   def test_load_dataset(self):
+    tf.random.set_seed(0)
     max_seq_length = 128
     batch_size = 2
     input_path_1 = os.path.join(self.get_temp_dir(), 'train_1.tf_record')
@@ -178,7 +180,8 @@ class PretrainDynamicDataLoaderTest(tf.test.TestCase, parameterized.TestCase):
         input_path=input_paths,
         seq_bucket_lengths=[64, 128],
         use_position_id=True,
-        global_batch_size=batch_size)
+        global_batch_size=batch_size,
+        deterministic=True)
     dataset = pretrain_dynamic_dataloader.PretrainingDynamicDataLoader(
         data_config).load()
     dataset_it = iter(dataset)
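The test-side hunks pair the NumPy seeding with TensorFlow-level determinism: a global seed via tf.random.set_seed(0) and a deterministic=True flag on the data config. A rough illustration of the same two knobs in plain tf.data follows; the config plumbing inside the model-garden loader is not reproduced here, and the mapping of its deterministic flag onto tf.data options is an assumption:

import tensorflow as tf

tf.random.set_seed(0)  # global seed; op-level seeds are derived from it

ds = tf.data.Dataset.range(100).shuffle(buffer_size=100)

# Ask tf.data for a deterministic element order even when transformations
# run in parallel; a deterministic=True style config flag plausibly toggles
# an option like this under the hood (assumption, not confirmed by the diff).
options = tf.data.Options()
options.deterministic = True
ds = ds.with_options(options)

print(list(ds.take(5).as_numpy_iterator()))  # same prefix on every run

Together with the seeded Generator in _create_fake_dataset, this makes both the fake input files and the loaded batch order stable, so the test's metric comparisons cannot flake on random ordering.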