ModelZoo / ResNet50_tensorflow · Commit adff6ed3

Authored Mar 23, 2021 by Hongkun Yu; committed by A. Unique TensorFlower, Mar 23, 2021.

Open source bert pretrain dataloader with dynamic sequence handling.

PiperOrigin-RevId: 364634581
Parent: fd2a02af

Showing 3 changed files with 499 additions and 24 deletions:

  official/nlp/configs/pretraining_experiments.py         +46   -24
  official/nlp/data/pretrain_dynamic_dataloader.py        +211   -0
  official/nlp/data/pretrain_dynamic_dataloader_test.py   +242   -0
official/nlp/configs/pretraining_experiments.py

@@ -18,8 +18,34 @@ from official.core import config_definitions as cfg
 from official.core import exp_factory
 from official.modeling import optimization
 from official.nlp.data import pretrain_dataloader
+from official.nlp.data import pretrain_dynamic_dataloader
 from official.nlp.tasks import masked_lm
 
+
+_TRAINER = cfg.TrainerConfig(
+    train_steps=1000000,
+    optimizer_config=optimization.OptimizationConfig({
+        'optimizer': {
+            'type': 'adamw',
+            'adamw': {
+                'weight_decay_rate': 0.01,
+                'exclude_from_weight_decay': ['LayerNorm', 'layer_norm', 'bias'],
+            }
+        },
+        'learning_rate': {
+            'type': 'polynomial',
+            'polynomial': {
+                'initial_learning_rate': 1e-4,
+                'end_learning_rate': 0.0,
+            }
+        },
+        'warmup': {
+            'type': 'polynomial'
+        }
+    }))
+
 
 @exp_factory.register_config_factory('bert/pretraining')
 def bert_pretraining() -> cfg.ExperimentConfig:

@@ -29,30 +55,26 @@ def bert_pretraining() -> cfg.ExperimentConfig:
           train_data=pretrain_dataloader.BertPretrainDataConfig(),
           validation_data=pretrain_dataloader.BertPretrainDataConfig(
               is_training=False)),
-      trainer=cfg.TrainerConfig(
-          train_steps=1000000,
-          optimizer_config=optimization.OptimizationConfig({
-              'optimizer': {
-                  'type': 'adamw',
-                  'adamw': {
-                      'weight_decay_rate': 0.01,
-                      'exclude_from_weight_decay': ['LayerNorm', 'layer_norm', 'bias'],
-                  }
-              },
-              'learning_rate': {
-                  'type': 'polynomial',
-                  'polynomial': {
-                      'initial_learning_rate': 1e-4,
-                      'end_learning_rate': 0.0,
-                  }
-              },
-              'warmup': {
-                  'type': 'polynomial'
-              }
-          })),
+      trainer=_TRAINER,
       restrictions=[
           'task.train_data.is_training != None',
           'task.validation_data.is_training != None'
       ])
   return config
 
 
+@exp_factory.register_config_factory('bert/pretraining_dynamic')
+def bert_dynamic() -> cfg.ExperimentConfig:
+  """BERT base with dynamic input sequences.
+
+  TPU needs to run with tf.data service with round-robin behavior.
+  """
+  config = cfg.ExperimentConfig(
+      task=masked_lm.MaskedLMConfig(
+          train_data=pretrain_dynamic_dataloader.BertPretrainDataConfig(),
+          validation_data=pretrain_dataloader.BertPretrainDataConfig(
+              is_training=False)),
+      trainer=_TRAINER,
+      restrictions=[
+          'task.train_data.is_training != None',
+          'task.validation_data.is_training != None'
...
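For orientation, and not part of the commit itself, here is a minimal sketch of how the newly registered 'bert/pretraining_dynamic' experiment could be looked up. It assumes `exp_factory.get_exp_config` is the lookup counterpart of `register_config_factory` in the Model Garden core; the printed values follow from the configs shown in this diff.

# Minimal sketch (assumption: exp_factory.get_exp_config resolves names
# registered via register_config_factory).
from official.core import exp_factory

config = exp_factory.get_exp_config('bert/pretraining_dynamic')
# Train data uses the dynamic-bucketing BertPretrainDataConfig from this commit.
print(type(config.task.train_data))
print(config.task.train_data.seq_bucket_lengths)  # (128, 256, 384, 512) by default
print(config.trainer.train_steps)                 # 1000000, via the shared _TRAINER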
official/nlp/data/pretrain_dynamic_dataloader.py (new file, mode 100644)

# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Dataset loader for pre-training with dynamic sequence length."""
from typing import Optional, Tuple

import dataclasses
import tensorflow as tf

from official.core import config_definitions as cfg
from official.core import input_reader
from official.nlp.data import data_loader_factory
from official.nlp.data import pretrain_dataloader


@dataclasses.dataclass
class BertPretrainDataConfig(cfg.DataConfig):
  """Data config for BERT pretraining task (tasks/masked_lm)."""
  input_path: str = ''
  global_batch_size: int = 512
  is_training: bool = True
  seq_bucket_lengths: Tuple[int, ...] = (128, 256, 384, 512,)
  # TODO(rxsang): `seq_bucket_window_scale` is only useful when round robin
  # tf.data service is disabled. Deprecate this flag once we always enable
  # round robin tf.data service.
  seq_bucket_window_scale: int = 8
  use_next_sentence_label: bool = True
  use_position_id: bool = False
  deterministic: bool = False
  enable_tf_data_service: bool = False
  enable_round_robin_tf_data_service: bool = False
  tf_data_service_job_name: str = 'bert_pretrain'
  use_v2_feature_names: bool = False


@data_loader_factory.register_data_loader_cls(BertPretrainDataConfig)
class PretrainingDynamicDataLoader(pretrain_dataloader.BertPretrainDataLoader):
  """Dataset loader for BERT-style pretraining with dynamic sequence length.

  Bucketizes the input id features by seq_bucket_lengths, and features are
  padded to the bucket boundaries. The mask features are usually shorter than
  the input id features and can also be dynamic. We require that the mask
  feature lengths within a bucket be the same. For example, with [128, 256]
  buckets, the mask features for bucket 128 should always have some length X
  and the mask features for bucket 256 should always have some length Y.

  The dataloader does not filter out empty masks. Make sure to handle this
  in the model.
  """

  def __init__(self, params):
    self._params = params
    if len(params.seq_bucket_lengths) < 1:
      raise ValueError('The seq_bucket_lengths cannot be empty.')
    self._seq_bucket_lengths = params.seq_bucket_lengths
    self._seq_bucket_window_scale = params.seq_bucket_window_scale
    self._global_batch_size = params.global_batch_size
    self._use_next_sentence_label = params.use_next_sentence_label
    self._use_position_id = params.use_position_id
    self._drop_remainder = params.drop_remainder
    self._enable_tf_data_service = params.enable_tf_data_service
    self._enable_round_robin_tf_data_service = (
        params.enable_round_robin_tf_data_service)
    self._mask_keys = [
        'masked_lm_positions', 'masked_lm_ids', 'masked_lm_weights'
    ]

  def _decode(self, record: tf.Tensor):
    """Decodes a serialized tf.Example."""
    name_to_features = {
        'input_ids': tf.io.VarLenFeature(tf.int64),
        'input_mask': tf.io.VarLenFeature(tf.int64),
        'segment_ids': tf.io.VarLenFeature(tf.int64),
        'masked_lm_positions': tf.io.VarLenFeature(tf.int64),
        'masked_lm_ids': tf.io.VarLenFeature(tf.int64),
        'masked_lm_weights': tf.io.VarLenFeature(tf.float32),
    }
    if self._use_next_sentence_label:
      name_to_features['next_sentence_labels'] = tf.io.FixedLenFeature(
          [1], tf.int64)
    dynamic_keys = ['input_ids', 'input_mask', 'segment_ids']
    if self._use_position_id:
      name_to_features['position_ids'] = tf.io.VarLenFeature(tf.int64)
      dynamic_keys.append('position_ids')

    example = tf.io.parse_single_example(record, name_to_features)
    for key in dynamic_keys + self._mask_keys:
      example[key] = tf.sparse.to_dense(example[key])

    # Truncate padding that follows the last non-pad token in the sequence
    # length dimension. Padding that appears before the last non-pad token
    # should not be removed.
    mask = tf.math.greater(
        tf.math.cumsum(example['input_ids'], reverse=True), 0)
    for key in dynamic_keys:
      example[key] = tf.boolean_mask(example[key], mask)

    # masked_lm_ids is 0-padded on disk. Switch its padding to -1 so that this
    # padding can later be differentiated from the 0 padding added by
    # bucketizing.
    mask = tf.math.not_equal(example['masked_lm_ids'], 0)
    example['masked_lm_ids'] = tf.where(
        mask, example['masked_lm_ids'],
        -tf.ones(tf.shape(example['masked_lm_ids']), dtype=example[key].dtype))

    # tf.Example only supports tf.int64, but the TPU only supports tf.int32,
    # so cast all int64 to int32.
    # tf.data service uses the dataset graph fingerprint to distinguish input
    # pipeline jobs, thus we sort the keys here to make sure they are generated
    # in a deterministic order each time the dataset function is traced.
    for name in sorted(list(example.keys())):
      t = example[name]
      if t.dtype == tf.int64:
        t = tf.cast(t, tf.int32)
      example[name] = t

    return example

  def _bucketize_and_batch(
      self,
      dataset,
      input_context: Optional[tf.distribute.InputContext] = None):
    """Bucketize by sequence length and batch the datasets."""
    per_replica_batch_size = input_context.get_per_replica_batch_size(
        self._global_batch_size) if input_context else self._global_batch_size

    def element_length_func(example, seq_len_dim):
      return tf.shape(example['input_word_ids'])[seq_len_dim]

    bucket_boundaries = [length + 1 for length in self._seq_bucket_lengths]
    bucket_batch_sizes = [per_replica_batch_size] * (len(bucket_boundaries) + 1)

    # Bucketize and batch the dataset with per replica batch size first.
    dataset = dataset.apply(
        tf.data.experimental.bucket_by_sequence_length(
            lambda example: tf.cast(element_length_func(example, 0), tf.int32),
            bucket_boundaries,
            bucket_batch_sizes,
            pad_to_bucket_boundary=True,
            drop_remainder=self._drop_remainder))
    if input_context:
      window_size = input_context.num_replicas_in_sync
      if self._enable_tf_data_service and (
          not self._enable_round_robin_tf_data_service):
        # If tf.data service is enabled but round-robin behavior is not
        # enabled, different TPU workers may fetch data from one tf.data
        # service worker at different speeds. We make the window size
        # `seq_bucket_window_scale` times larger to leave a buffer if some
        # workers are fetching data faster than others, so all the data within
        # the same global batch still has a better chance of landing in the
        # same bucket.
        window_size *= self._seq_bucket_window_scale
      # Group `num_replicas_in_sync` batches from the same bucket together, so
      # all replicas get the same sequence length for one global step.
      dataset = dataset.apply(
          tf.data.experimental.group_by_window(
              key_func=lambda example: tf.cast(  # pylint: disable=g-long-lambda
                  element_length_func(example, 1), tf.int64),
              reduce_func=lambda _, x: tf.data.Dataset.from_tensors(x),
              window_size=window_size))
      dataset = dataset.flat_map(lambda x: x)

    def _remove_pads_from_bucketize(features):
      # All mask features must have the same effective length.
      # The real masked ids padding token is -1; 0 comes from
      # bucket_by_sequence_length.
      mask = tf.math.not_equal(features['masked_lm_ids'], 0)
      mask_per_example = tf.math.reduce_sum(tf.cast(mask, tf.int32), axis=1)
      normalized = tf.cast(
          mask_per_example / tf.math.reduce_max(mask_per_example), tf.int32)
      assert_op = tf.debugging.assert_equal(
          tf.math.reduce_sum(normalized), per_replica_batch_size,
          'Number of non padded mask tokens is not the same for each example '
          'in the same sequence length.')
      with tf.control_dependencies([assert_op]):
        for key in self._mask_keys:
          features[key] = tf.reshape(
              tf.boolean_mask(features[key], mask),
              [per_replica_batch_size, -1])
      # Revert masked_lm_ids to be 0-padded.
      mask = tf.math.not_equal(features['masked_lm_ids'], -1)
      features['masked_lm_ids'] = tf.where(
          mask, features['masked_lm_ids'],
          tf.zeros(
              tf.shape(features['masked_lm_ids']),
              dtype=features['masked_lm_ids'].dtype))
      return features

    dataset = dataset.map(_remove_pads_from_bucketize)
    return dataset

  def load(self, input_context: Optional[tf.distribute.InputContext] = None):
    """Returns a tf.dataset.Dataset."""
    reader = input_reader.InputReader(
        params=self._params,
        decoder_fn=self._decode,
        parser_fn=self._parse,
        transform_and_batch_fn=self._bucketize_and_batch)
    return reader.read(input_context)
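The loader above pads each example to its bucket's configured length by passing boundaries of `length + 1` together with `pad_to_bucket_boundary=True`. Below is a minimal, self-contained sketch of that behavior, not part of the commit, using only `tf.data.experimental.bucket_by_sequence_length` on made-up sequence lengths.

# Standalone sketch of the bucketing rule used by PretrainingDynamicDataLoader.
import numpy as np
import tensorflow as tf

seq_bucket_lengths = (64, 128)
# Boundaries are bucket lengths + 1; with pad_to_bucket_boundary=True each
# bucket is padded to `boundary - 1`, i.e. exactly the configured length.
bucket_boundaries = [length + 1 for length in seq_bucket_lengths]  # [65, 129]

def gen():
  # Made-up "token id" sequences: two of length 60 and two of length 100.
  for n in (60, 60, 100, 100):
    yield np.ones([n], np.int32)

ds = tf.data.Dataset.from_generator(
    gen, output_signature=tf.TensorSpec(shape=[None], dtype=tf.int32))
ds = ds.apply(
    tf.data.experimental.bucket_by_sequence_length(
        element_length_func=lambda x: tf.shape(x)[0],
        bucket_boundaries=bucket_boundaries,
        bucket_batch_sizes=[2, 2, 2],  # len(bucket_boundaries) + 1 entries
        pad_to_bucket_boundary=True))

for batch in ds:
  # Expect (2, 64) for the length-60 pair and (2, 128) for the length-100 pair,
  # matching the shape assertions in the test file below.
  print(batch.shape)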
official/nlp/data/pretrain_dynamic_dataloader_test.py (new file, mode 100644)

# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for nlp.data.pretrain_dynamic_dataloader."""
import os

from absl import logging
from absl.testing import parameterized
import numpy as np
import orbit
import tensorflow as tf

from tensorflow.python.distribute import combinations
from tensorflow.python.distribute import strategy_combinations
from official.nlp.configs import bert
from official.nlp.configs import encoders
from official.nlp.data import pretrain_dataloader
from official.nlp.data import pretrain_dynamic_dataloader
from official.nlp.tasks import masked_lm


def _create_fake_dataset(output_path, seq_length, num_masked_tokens,
                         max_seq_length, num_examples):
  """Creates a fake dataset."""
  writer = tf.io.TFRecordWriter(output_path)

  def create_int_feature(values):
    f = tf.train.Feature(int64_list=tf.train.Int64List(value=list(values)))
    return f

  def create_float_feature(values):
    f = tf.train.Feature(float_list=tf.train.FloatList(value=list(values)))
    return f

  for _ in range(num_examples):
    features = {}
    padding = np.zeros(shape=(max_seq_length - seq_length), dtype=np.int32)
    input_ids = np.random.randint(low=1, high=100, size=(seq_length))
    features['input_ids'] = create_int_feature(
        np.concatenate((input_ids, padding)))
    features['input_mask'] = create_int_feature(
        np.concatenate((np.ones_like(input_ids), padding)))
    features['segment_ids'] = create_int_feature(
        np.concatenate((np.ones_like(input_ids), padding)))
    features['position_ids'] = create_int_feature(
        np.concatenate((np.ones_like(input_ids), padding)))
    features['masked_lm_positions'] = create_int_feature(
        np.random.randint(60, size=(num_masked_tokens), dtype=np.int64))
    features['masked_lm_ids'] = create_int_feature(
        np.random.randint(100, size=(num_masked_tokens), dtype=np.int64))
    features['masked_lm_weights'] = create_float_feature(
        np.ones((num_masked_tokens,), dtype=np.float32))
    features['next_sentence_labels'] = create_int_feature(np.array([0]))

    tf_example = tf.train.Example(features=tf.train.Features(feature=features))
    writer.write(tf_example.SerializeToString())
  writer.close()


class PretrainDynamicDataLoaderTest(tf.test.TestCase, parameterized.TestCase):

  @combinations.generate(
      combinations.combine(
          distribution_strategy=[
              strategy_combinations.cloud_tpu_strategy,
          ],
          mode='eager'))
  def test_distribution_strategy(self, distribution_strategy):
    max_seq_length = 128
    batch_size = 8
    input_path = os.path.join(self.get_temp_dir(), 'train.tf_record')
    _create_fake_dataset(
        input_path,
        seq_length=60,
        num_masked_tokens=20,
        max_seq_length=max_seq_length,
        num_examples=batch_size)
    data_config = pretrain_dynamic_dataloader.BertPretrainDataConfig(
        is_training=False,
        input_path=input_path,
        seq_bucket_lengths=[64, 128],
        global_batch_size=batch_size)
    dataloader = pretrain_dynamic_dataloader.PretrainingDynamicDataLoader(
        data_config)
    distributed_ds = orbit.utils.make_distributed_dataset(
        distribution_strategy, dataloader.load)
    train_iter = iter(distributed_ds)
    with distribution_strategy.scope():
      config = masked_lm.MaskedLMConfig(
          init_checkpoint=self.get_temp_dir(),
          model=bert.PretrainerConfig(
              encoders.EncoderConfig(
                  bert=encoders.BertEncoderConfig(
                      vocab_size=30522, num_layers=1)),
              cls_heads=[
                  bert.ClsHeadConfig(
                      inner_dim=10, num_classes=2, name='next_sentence')
              ]),
          train_data=data_config)
      task = masked_lm.MaskedLMTask(config)
      model = task.build_model()
      metrics = task.build_metrics()

    @tf.function
    def step_fn(features):
      return task.validation_step(features, model, metrics=metrics)

    distributed_outputs = distribution_strategy.run(
        step_fn, args=(next(train_iter),))
    local_results = tf.nest.map_structure(
        distribution_strategy.experimental_local_results, distributed_outputs)
    logging.info('Dynamic padding: local_results= %s', str(local_results))
    dynamic_metrics = {}
    for metric in metrics:
      dynamic_metrics[metric.name] = metric.result()

    data_config = pretrain_dataloader.BertPretrainDataConfig(
        is_training=False,
        input_path=input_path,
        seq_length=max_seq_length,
        max_predictions_per_seq=20,
        global_batch_size=batch_size)
    dataloader = pretrain_dataloader.BertPretrainDataLoader(data_config)
    distributed_ds = orbit.utils.make_distributed_dataset(
        distribution_strategy, dataloader.load)
    train_iter = iter(distributed_ds)
    with distribution_strategy.scope():
      metrics = task.build_metrics()

    @tf.function
    def step_fn_b(features):
      return task.validation_step(features, model, metrics=metrics)

    distributed_outputs = distribution_strategy.run(
        step_fn_b, args=(next(train_iter),))
    local_results = tf.nest.map_structure(
        distribution_strategy.experimental_local_results, distributed_outputs)
    logging.info('Static padding: local_results= %s', str(local_results))
    static_metrics = {}
    for metric in metrics:
      static_metrics[metric.name] = metric.result()
    for key in static_metrics:
      # We need to investigate the differences on losses.
      if key != 'next_sentence_loss':
        self.assertEqual(dynamic_metrics[key], static_metrics[key])

  def test_load_dataset(self):
    max_seq_length = 128
    batch_size = 2
    input_path_1 = os.path.join(self.get_temp_dir(), 'train_1.tf_record')
    _create_fake_dataset(
        input_path_1,
        seq_length=60,
        num_masked_tokens=20,
        max_seq_length=max_seq_length,
        num_examples=batch_size)
    input_path_2 = os.path.join(self.get_temp_dir(), 'train_2.tf_record')
    _create_fake_dataset(
        input_path_2,
        seq_length=100,
        num_masked_tokens=70,
        max_seq_length=max_seq_length,
        num_examples=batch_size)
    input_paths = ','.join([input_path_1, input_path_2])
    data_config = pretrain_dynamic_dataloader.BertPretrainDataConfig(
        is_training=False,
        input_path=input_paths,
        seq_bucket_lengths=[64, 128],
        use_position_id=True,
        global_batch_size=batch_size)
    dataset = pretrain_dynamic_dataloader.PretrainingDynamicDataLoader(
        data_config).load()
    dataset_it = iter(dataset)
    features = next(dataset_it)
    self.assertCountEqual([
        'input_word_ids',
        'input_mask',
        'input_type_ids',
        'next_sentence_labels',
        'masked_lm_positions',
        'masked_lm_ids',
        'masked_lm_weights',
        'position_ids',
    ], features.keys())
    # The sequence length dimension should be bucketized and padded to 64.
    self.assertEqual(features['input_word_ids'].shape, (batch_size, 64))
    self.assertEqual(features['input_mask'].shape, (batch_size, 64))
    self.assertEqual(features['input_type_ids'].shape, (batch_size, 64))
    self.assertEqual(features['position_ids'].shape, (batch_size, 64))
    self.assertEqual(features['masked_lm_positions'].shape, (batch_size, 20))
    features = next(dataset_it)
    self.assertEqual(features['input_word_ids'].shape, (batch_size, 128))
    self.assertEqual(features['input_mask'].shape, (batch_size, 128))
    self.assertEqual(features['input_type_ids'].shape, (batch_size, 128))
    self.assertEqual(features['position_ids'].shape, (batch_size, 128))
    self.assertEqual(features['masked_lm_positions'].shape, (batch_size, 70))

  def test_load_dataset_not_same_masks(self):
    max_seq_length = 128
    batch_size = 2
    input_path_1 = os.path.join(self.get_temp_dir(), 'train_3.tf_record')
    _create_fake_dataset(
        input_path_1,
        seq_length=60,
        num_masked_tokens=20,
        max_seq_length=max_seq_length,
        num_examples=batch_size)
    input_path_2 = os.path.join(self.get_temp_dir(), 'train_4.tf_record')
    _create_fake_dataset(
        input_path_2,
        seq_length=60,
        num_masked_tokens=15,
        max_seq_length=max_seq_length,
        num_examples=batch_size)
    input_paths = ','.join([input_path_1, input_path_2])
    data_config = pretrain_dynamic_dataloader.BertPretrainDataConfig(
        is_training=False,
        input_path=input_paths,
        seq_bucket_lengths=[64, 128],
        use_position_id=True,
        global_batch_size=batch_size * 2)
    dataset = pretrain_dynamic_dataloader.PretrainingDynamicDataLoader(
        data_config).load()
    dataset_it = iter(dataset)
    with self.assertRaisesRegex(
        tf.errors.InvalidArgumentError,
        '.*Number of non padded mask tokens.*'):
      next(dataset_it)


if __name__ == '__main__':
  tf.test.main()
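Because the new loader registers itself for `BertPretrainDataConfig` via `data_loader_factory`, a task normally does not instantiate it directly. The following sketch, not part of the commit, shows how the loader could be obtained through the registry; it assumes `data_loader_factory.get_data_loader` is the registry lookup helper, and the input path is hypothetical.

# Hedged usage sketch for the registered dynamic dataloader.
from official.nlp.data import data_loader_factory
from official.nlp.data import pretrain_dynamic_dataloader

config = pretrain_dynamic_dataloader.BertPretrainDataConfig(
    is_training=False,
    input_path='/tmp/pretrain*.tf_record',  # hypothetical TFRecord glob
    seq_bucket_lengths=[64, 128],
    global_batch_size=8)
# The registry maps this config class to PretrainingDynamicDataLoader.
loader = data_loader_factory.get_data_loader(config)
dataset = loader.load()
for features in dataset.take(1):
  print({name: tensor.shape for name, tensor in features.items()})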