ModelZoo / ResNet50_tensorflow · Commits · 5cc6df63

Commit 5cc6df63, authored Nov 08, 2021 by Hongkun Yu, committed by A. Unique TensorFlower on Nov 08, 2021.

Open source dual encoder tasks and dataloaders.

PiperOrigin-RevId: 408397786
parent e97979cb
Changes: 4 changed files with 605 additions and 0 deletions (+605, -0)
official/nlp/data/dual_encoder_dataloader.py        +145  -0
official/nlp/data/dual_encoder_dataloader_test.py   +131  -0
official/nlp/tasks/dual_encoder.py                  +203  -0
official/nlp/tasks/dual_encoder_test.py             +126  -0
official/nlp/data/dual_encoder_dataloader.py  (new file, mode 100644)
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Loads dataset for the dual encoder (retrieval) task."""
import functools
import itertools
from typing import Iterable, Mapping, Optional, Tuple

import dataclasses
import tensorflow as tf
import tensorflow_hub as hub

from official.core import config_definitions as cfg
from official.core import input_reader
from official.nlp.data import data_loader
from official.nlp.data import data_loader_factory
from official.nlp.modeling import layers


@dataclasses.dataclass
class DualEncoderDataConfig(cfg.DataConfig):
  """Data config for dual encoder task (tasks/dual_encoder)."""
  # Either set `input_path`...
  input_path: str = ''
  # ...or `tfds_name` and `tfds_split` to specify input.
  tfds_name: str = ''
  tfds_split: str = ''
  global_batch_size: int = 32
  # Either build preprocessing with Python code by specifying these values...
  vocab_file: str = ''
  lower_case: bool = True
  # ...or load preprocessing from a SavedModel at this location.
  preprocessing_hub_module_url: str = ''
  left_text_fields: Tuple[str] = ('left_input',)
  right_text_fields: Tuple[str] = ('right_input',)
  is_training: bool = True
  seq_length: int = 128


@data_loader_factory.register_data_loader_cls(DualEncoderDataConfig)
class DualEncoderDataLoader(data_loader.DataLoader):
  """A class to load dataset for dual encoder task (tasks/dual_encoder)."""

  def __init__(self, params):
    if bool(params.tfds_name) == bool(params.input_path):
      raise ValueError('Must specify either `tfds_name` and `tfds_split` '
                       'or `input_path`.')
    if bool(params.vocab_file) == bool(params.preprocessing_hub_module_url):
      raise ValueError('Must specify exactly one of vocab_file (with matching '
                       'lower_case flag) or preprocessing_hub_module_url.')
    self._params = params
    self._seq_length = params.seq_length
    self._left_text_fields = params.left_text_fields
    self._right_text_fields = params.right_text_fields

    if params.preprocessing_hub_module_url:
      preprocessing_hub_module = hub.load(params.preprocessing_hub_module_url)
      self._tokenizer = preprocessing_hub_module.tokenize
      self._pack_inputs = functools.partial(
          preprocessing_hub_module.bert_pack_inputs,
          seq_length=params.seq_length)
    else:
      self._tokenizer = layers.BertTokenizer(
          vocab_file=params.vocab_file, lower_case=params.lower_case)
      self._pack_inputs = layers.BertPackInputs(
          seq_length=params.seq_length,
          special_tokens_dict=self._tokenizer.get_special_tokens_dict())

  def _decode(self, record: tf.Tensor):
    """Decodes a serialized tf.Example."""
    name_to_features = {
        x: tf.io.FixedLenFeature([], tf.string)
        for x in itertools.chain(
            *[self._left_text_fields, self._right_text_fields])
    }
    example = tf.io.parse_single_example(record, name_to_features)

    # tf.Example only supports tf.int64, but the TPU only supports tf.int32.
    # So cast all int64 to int32.
    for name in example:
      t = example[name]
      if t.dtype == tf.int64:
        t = tf.cast(t, tf.int32)
      example[name] = t

    return example
  def _bert_tokenize(
      self, record: Mapping[str, tf.Tensor],
      text_fields: Iterable[str]) -> Tuple[tf.Tensor, tf.Tensor, tf.Tensor]:
    """Tokenizes the inputs in text_fields using the BERT tokenizer.

    Args:
      record: A tf.Example record containing the features.
      text_fields: A list of fields to be tokenized.

    Returns:
      The tokenized features in a tuple of (input_word_ids, input_mask,
      input_type_ids).
    """
    segments_text = [record[x] for x in text_fields]
    segments_tokens = [self._tokenizer(s) for s in segments_text]
    segments = [
        tf.cast(x.merge_dims(1, 2), tf.int32) for x in segments_tokens
    ]
    return self._pack_inputs(segments)
  def _bert_preprocess(
      self, record: Mapping[str, tf.Tensor]) -> Mapping[str, tf.Tensor]:
    """Performs BERT word-piece tokenization for the left and right inputs."""

    def _switch_prefix(string, old, new):
      if string.startswith(old):
        return new + string[len(old):]
      raise ValueError('Expected {} to start with {}'.format(string, old))

    def _switch_key_prefix(d, old, new):
      return {_switch_prefix(key, old, new): value for key, value in d.items()}

    model_inputs = _switch_key_prefix(
        self._bert_tokenize(record, self._left_text_fields),
        'input_', 'left_')
    model_inputs.update(_switch_key_prefix(
        self._bert_tokenize(record, self._right_text_fields),
        'input_', 'right_'))
    return model_inputs
  def load(self, input_context: Optional[tf.distribute.InputContext] = None):
    """Returns a tf.data.Dataset."""
    reader = input_reader.InputReader(
        params=self._params,
        # Skip `decoder_fn` for tfds input.
        decoder_fn=self._decode if self._params.input_path else None,
        dataset_fn=tf.data.TFRecordDataset,
        postprocess_fn=self._bert_preprocess)
    return reader.read(input_context)
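For reference, a minimal usage sketch of the loader defined above. The TFRecord and vocab paths are hypothetical placeholders; the config fields and the resulting feature keys follow the code in this file.

import tensorflow as tf

from official.nlp.data import dual_encoder_dataloader

# Hypothetical paths -- substitute your own TFRecords and BERT vocab file.
data_config = dual_encoder_dataloader.DualEncoderDataConfig(
    input_path='/tmp/dual_encoder/train.tf_record',
    vocab_file='/tmp/dual_encoder/vocab.txt',
    lower_case=True,
    left_text_fields=('left_input',),
    right_text_fields=('right_input',),
    seq_length=128,
    global_batch_size=32)

# Each batch is a dict with left_/right_ word ids, masks, and type ids,
# as produced by `_bert_preprocess` above.
dataset = dual_encoder_dataloader.DualEncoderDataLoader(data_config).load()
features = next(iter(dataset))
print(features.keys())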
official/nlp/data/dual_encoder_dataloader_test.py  (new file, mode 100644)
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for official.nlp.data.dual_encoder_dataloader."""
import os

from absl.testing import parameterized
import tensorflow as tf

from official.nlp.data import dual_encoder_dataloader

_LEFT_FEATURE_NAME = 'left_input'
_RIGHT_FEATURE_NAME = 'right_input'
def _create_fake_dataset(output_path):
  """Creates a fake dataset containing examples for training a dual encoder model.

  The created dataset contains examples with two bytes_list features keyed by
  _LEFT_FEATURE_NAME and _RIGHT_FEATURE_NAME.

  Args:
    output_path: The output path of the fake dataset.
  """
  def create_str_feature(values):
    return tf.train.Feature(bytes_list=tf.train.BytesList(value=values))

  with tf.io.TFRecordWriter(output_path) as writer:
    for _ in range(100):
      features = {}
      features[_LEFT_FEATURE_NAME] = create_str_feature([b'hello world.'])
      features[_RIGHT_FEATURE_NAME] = create_str_feature([b'world hello.'])

      tf_example = tf.train.Example(
          features=tf.train.Features(feature=features))
      writer.write(tf_example.SerializeToString())
def _make_vocab_file(vocab, output_path):
  with tf.io.gfile.GFile(output_path, 'w') as f:
    f.write('\n'.join(vocab + ['']))


class DualEncoderDataTest(tf.test.TestCase, parameterized.TestCase):

  def test_load_dataset(self):
    seq_length = 16
    batch_size = 10
    train_data_path = os.path.join(self.get_temp_dir(), 'train.tf_record')
    vocab_path = os.path.join(self.get_temp_dir(), 'vocab.txt')
    _create_fake_dataset(train_data_path)
    _make_vocab_file(
        ['[PAD]', '[UNK]', '[CLS]', '[SEP]', 'he', '#llo', 'world'],
        vocab_path)

    data_config = dual_encoder_dataloader.DualEncoderDataConfig(
        input_path=train_data_path,
        seq_length=seq_length,
        vocab_file=vocab_path,
        lower_case=True,
        left_text_fields=(_LEFT_FEATURE_NAME,),
        right_text_fields=(_RIGHT_FEATURE_NAME,),
        global_batch_size=batch_size)
    dataset = dual_encoder_dataloader.DualEncoderDataLoader(
        data_config).load()
    features = next(iter(dataset))
    self.assertCountEqual(
        ['left_word_ids', 'left_mask', 'left_type_ids', 'right_word_ids',
         'right_mask', 'right_type_ids'],
        features.keys())
    self.assertEqual(features['left_word_ids'].shape,
                     (batch_size, seq_length))
    self.assertEqual(features['left_mask'].shape, (batch_size, seq_length))
    self.assertEqual(features['left_type_ids'].shape,
                     (batch_size, seq_length))
    self.assertEqual(features['right_word_ids'].shape,
                     (batch_size, seq_length))
    self.assertEqual(features['right_mask'].shape, (batch_size, seq_length))
    self.assertEqual(features['right_type_ids'].shape,
                     (batch_size, seq_length))

  @parameterized.parameters(False, True)
  def test_load_tfds(self, use_preprocessing_hub):
    seq_length = 16
    batch_size = 10
    if use_preprocessing_hub:
      vocab_path = ''
      preprocessing_hub = (
          'https://tfhub.dev/tensorflow/bert_multi_cased_preprocess/3')
    else:
      vocab_path = os.path.join(self.get_temp_dir(), 'vocab.txt')
      _make_vocab_file(
          ['[PAD]', '[UNK]', '[CLS]', '[SEP]', 'he', '#llo', 'world'],
          vocab_path)
      preprocessing_hub = ''

    data_config = dual_encoder_dataloader.DualEncoderDataConfig(
        tfds_name='para_crawl/enmt',
        tfds_split='train',
        seq_length=seq_length,
        vocab_file=vocab_path,
        lower_case=True,
        left_text_fields=('en',),
        right_text_fields=('mt',),
        preprocessing_hub_module_url=preprocessing_hub,
        global_batch_size=batch_size)
    dataset = dual_encoder_dataloader.DualEncoderDataLoader(
        data_config).load()
    features = next(iter(dataset))
    self.assertCountEqual(
        ['left_word_ids', 'left_mask', 'left_type_ids', 'right_word_ids',
         'right_mask', 'right_type_ids'],
        features.keys())
    self.assertEqual(features['left_word_ids'].shape,
                     (batch_size, seq_length))
    self.assertEqual(features['left_mask'].shape, (batch_size, seq_length))
    self.assertEqual(features['left_type_ids'].shape,
                     (batch_size, seq_length))
    self.assertEqual(features['right_word_ids'].shape,
                     (batch_size, seq_length))
    self.assertEqual(features['right_mask'].shape, (batch_size, seq_length))
    self.assertEqual(features['right_type_ids'].shape,
                     (batch_size, seq_length))


if __name__ == '__main__':
  tf.test.main()
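The fake-dataset helper above also doubles as a template for preparing real training data. A minimal sketch, assuming made-up text pairs and an output path; the feature keys must match `left_text_fields` / `right_text_fields` in the data config (here the defaults, 'left_input' and 'right_input').

import tensorflow as tf

# Hypothetical text pairs; in practice these come from your own corpus.
pairs = [
    ('what is a dual encoder?',
     'A two-tower model scored by the similarity of its two encodings.'),
    ('hello world.', 'world hello.'),
]

with tf.io.TFRecordWriter('/tmp/dual_encoder/train.tf_record') as writer:
  for left, right in pairs:
    features = {
        'left_input': tf.train.Feature(
            bytes_list=tf.train.BytesList(value=[left.encode('utf-8')])),
        'right_input': tf.train.Feature(
            bytes_list=tf.train.BytesList(value=[right.encode('utf-8')])),
    }
    example = tf.train.Example(features=tf.train.Features(feature=features))
    writer.write(example.SerializeToString())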
official/nlp/tasks/dual_encoder.py  (new file, mode 100644)
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Dual encoder (retrieval) task."""
from typing import Mapping, Tuple

# Import libraries
from absl import logging
import dataclasses
import tensorflow as tf

from official.core import base_task
from official.core import config_definitions as cfg
from official.core import task_factory
from official.modeling import tf_utils
from official.modeling.hyperparams import base_config
from official.nlp.configs import encoders
from official.nlp.data import data_loader_factory
from official.nlp.modeling import models
from official.nlp.tasks import utils
@dataclasses.dataclass
class ModelConfig(base_config.Config):
  """A dual encoder (retrieval) configuration."""
  # Normalize input embeddings if set to True.
  normalize: bool = True
  # Maximum input sequence length.
  max_sequence_length: int = 64
  # Parameters for training a dual encoder model with additive margin, see
  # https://www.ijcai.org/Proceedings/2019/0746.pdf for more details.
  logit_scale: float = 1
  logit_margin: float = 0
  bidirectional: bool = False
  # Defines the k values for the recall@k evaluation metrics.
  eval_top_k: Tuple[int, ...] = (1, 3, 10)
  encoder: encoders.EncoderConfig = (encoders.EncoderConfig())


@dataclasses.dataclass
class DualEncoderConfig(cfg.TaskConfig):
  """The model config."""
  # At most one of `init_checkpoint` and `hub_module_url` can
  # be specified.
  init_checkpoint: str = ''
  hub_module_url: str = ''
  # Defines the concrete model config at instantiation time.
  model: ModelConfig = ModelConfig()
  train_data: cfg.DataConfig = cfg.DataConfig()
  validation_data: cfg.DataConfig = cfg.DataConfig()
@task_factory.register_task_cls(DualEncoderConfig)
class DualEncoderTask(base_task.Task):
  """Task object for dual encoder."""

  def build_model(self):
    """Interface to build model. Refer to base_task.Task.build_model."""
    if (self.task_config.hub_module_url and
        self.task_config.init_checkpoint):
      raise ValueError('At most one of `hub_module_url` and '
                       '`init_checkpoint` can be specified.')
    if self.task_config.hub_module_url:
      encoder_network = utils.get_encoder_from_hub(
          self.task_config.hub_module_url)
    else:
      encoder_network = encoders.build_encoder(self.task_config.model.encoder)

    # Currently, only the BERT-style dual encoder is supported.
    return models.DualEncoder(
        network=encoder_network,
        max_seq_length=self.task_config.model.max_sequence_length,
        normalize=self.task_config.model.normalize,
        logit_scale=self.task_config.model.logit_scale,
        logit_margin=self.task_config.model.logit_margin,
        output='logits')
  def build_losses(self, labels, model_outputs, aux_losses=None) -> tf.Tensor:
    """Interface to compute losses. Refer to base_task.Task.build_losses."""
    del labels

    left_logits = model_outputs['left_logits']
    right_logits = model_outputs['right_logits']

    batch_size = tf_utils.get_shape_list(left_logits, name='batch_size')[0]
    ranking_labels = tf.range(batch_size)

    loss = tf_utils.safe_mean(
        tf.nn.sparse_softmax_cross_entropy_with_logits(
            labels=ranking_labels,
            logits=left_logits))

    if self.task_config.model.bidirectional:
      right_rank_loss = tf_utils.safe_mean(
          tf.nn.sparse_softmax_cross_entropy_with_logits(
              labels=ranking_labels,
              logits=right_logits))
      loss += right_rank_loss

    return tf.reduce_mean(loss)
  def build_inputs(self, params, input_context=None) -> tf.data.Dataset:
    """Returns tf.data.Dataset for the dual encoder task."""
    if params.input_path != 'dummy':
      return data_loader_factory.get_data_loader(params).load(input_context)

    def dummy_data(_):
      dummy_ids = tf.zeros((10, params.seq_length), dtype=tf.int32)
      x = dict(
          left_word_ids=dummy_ids,
          left_mask=dummy_ids,
          left_type_ids=dummy_ids,
          right_word_ids=dummy_ids,
          right_mask=dummy_ids,
          right_type_ids=dummy_ids)
      return x

    dataset = tf.data.Dataset.range(1)
    dataset = dataset.repeat()
    dataset = dataset.map(
        dummy_data, num_parallel_calls=tf.data.experimental.AUTOTUNE)
    return dataset
  def build_metrics(self, training=None):
    del training
    metrics = [tf.keras.metrics.Mean(name='batch_size_per_core')]
    for k in self.task_config.model.eval_top_k:
      metrics.append(tf.keras.metrics.SparseTopKCategoricalAccuracy(
          k=k, name=f'left_recall_at_{k}'))
      if self.task_config.model.bidirectional:
        metrics.append(tf.keras.metrics.SparseTopKCategoricalAccuracy(
            k=k, name=f'right_recall_at_{k}'))
    return metrics

  def process_metrics(self, metrics, labels, model_outputs):
    del labels
    metrics = dict([(metric.name, metric) for metric in metrics])

    left_logits = model_outputs['left_logits']
    right_logits = model_outputs['right_logits']

    batch_size = tf_utils.get_shape_list(
        left_logits, name='sequence_output_tensor')[0]
    ranking_labels = tf.range(batch_size)

    for k in self.task_config.model.eval_top_k:
      metrics[f'left_recall_at_{k}'].update_state(ranking_labels, left_logits)
      if self.task_config.model.bidirectional:
        metrics[f'right_recall_at_{k}'].update_state(
            ranking_labels, right_logits)

    metrics['batch_size_per_core'].update_state(batch_size)

  def validation_step(self, inputs, model: tf.keras.Model,
                      metrics=None) -> Mapping[str, tf.Tensor]:
    outputs = model(inputs)
    loss = self.build_losses(
        labels=None, model_outputs=outputs, aux_losses=model.losses)
    logs = {self.loss: loss}
    if metrics:
      self.process_metrics(metrics, None, outputs)
      logs.update({m.name: m.result() for m in metrics})
    elif model.compiled_metrics:
      self.process_compiled_metrics(model.compiled_metrics, None, outputs)
      logs.update({m.name: m.result() for m in model.metrics})
    return logs
  def initialize(self, model):
    """Loads a pretrained checkpoint (if it exists), then trains from iteration 0."""
    ckpt_dir_or_file = self.task_config.init_checkpoint
    if tf.io.gfile.isdir(ckpt_dir_or_file):
      ckpt_dir_or_file = tf.train.latest_checkpoint(ckpt_dir_or_file)
    if not ckpt_dir_or_file:
      return

    pretrain2finetune_mapping = {
        'encoder': model.checkpoint_items['encoder'],
    }
    ckpt = tf.train.Checkpoint(**pretrain2finetune_mapping)
    status = ckpt.read(ckpt_dir_or_file)
    status.expect_partial().assert_existing_objects_matched()
    logging.info('Finished loading pretrained checkpoint from %s',
                 ckpt_dir_or_file)
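A note on the loss above: `build_losses` uses in-batch negatives, so the ranking label for row i is simply i (each left example's positive is the same-index right example, and every other row in the batch serves as a negative). The `logit_scale` and `logit_margin` fields are passed to `models.DualEncoder` and correspond to the additive-margin softmax from the paper linked in ModelConfig. The following toy NumPy sketch is illustrative only, an assumption of how the scaled, margin-adjusted similarity logits are formed before the softmax cross-entropy:

import numpy as np

def additive_margin_logits(left_emb, right_emb, scale=1.0, margin=0.0):
  """Toy sketch: cosine-similarity logits with an additive margin on positives."""
  left = left_emb / np.linalg.norm(left_emb, axis=1, keepdims=True)
  right = right_emb / np.linalg.norm(right_emb, axis=1, keepdims=True)
  sims = left @ right.T                      # [batch, batch] similarity matrix
  sims -= margin * np.eye(len(sims))         # subtract margin on the diagonal (positive pairs)
  return scale * sims

# In-batch negatives: the label for row i is i, i.e. tf.range(batch_size) above.
batch = 4
rng = np.random.default_rng(0)
logits = additive_margin_logits(rng.normal(size=(batch, 8)),
                                rng.normal(size=(batch, 8)),
                                scale=10.0, margin=0.3)
labels = np.arange(batch)
print(logits.shape, labels)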
official/nlp/tasks/dual_encoder_test.py  (new file, mode 100644)
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for official.nlp.tasks.sentence_prediction."""
import functools
import os

from absl.testing import parameterized
import tensorflow as tf

from official.nlp.bert import configs
from official.nlp.configs import bert
from official.nlp.configs import encoders
from official.nlp.data import dual_encoder_dataloader
from official.nlp.tasks import dual_encoder
from official.nlp.tasks import masked_lm
from official.nlp.tools import export_tfhub_lib


class DualEncoderTaskTest(tf.test.TestCase, parameterized.TestCase):

  def setUp(self):
    super(DualEncoderTaskTest, self).setUp()
    self._train_data_config = (
        dual_encoder_dataloader.DualEncoderDataConfig(
            input_path="dummy", seq_length=32))

  def get_model_config(self):
    return dual_encoder.ModelConfig(
        max_sequence_length=32,
        encoder=encoders.EncoderConfig(
            bert=encoders.BertEncoderConfig(vocab_size=30522, num_layers=1)))

  def _run_task(self, config):
    task = dual_encoder.DualEncoderTask(config)
    model = task.build_model()
    metrics = task.build_metrics()

    strategy = tf.distribute.get_strategy()
    dataset = strategy.distribute_datasets_from_function(
        functools.partial(task.build_inputs, config.train_data))
    dataset.batch(10)

    iterator = iter(dataset)
    optimizer = tf.keras.optimizers.SGD(lr=0.1)
    task.train_step(next(iterator), model, optimizer, metrics=metrics)
    task.validation_step(next(iterator), model, metrics=metrics)
    model.save(os.path.join(self.get_temp_dir(), "saved_model"))

  def test_task(self):
    config = dual_encoder.DualEncoderConfig(
        init_checkpoint=self.get_temp_dir(),
        model=self.get_model_config(),
        train_data=self._train_data_config)
    task = dual_encoder.DualEncoderTask(config)
    model = task.build_model()
    metrics = task.build_metrics()
    dataset = task.build_inputs(config.train_data)

    iterator = iter(dataset)
    optimizer = tf.keras.optimizers.SGD(lr=0.1)
    task.train_step(next(iterator), model, optimizer, metrics=metrics)
    task.validation_step(next(iterator), model, metrics=metrics)

    # Saves a checkpoint.
    pretrain_cfg = bert.PretrainerConfig(
        encoder=encoders.EncoderConfig(
            bert=encoders.BertEncoderConfig(vocab_size=30522, num_layers=1)))
    pretrain_model = masked_lm.MaskedLMTask(None).build_model(pretrain_cfg)
    ckpt = tf.train.Checkpoint(
        model=pretrain_model, **pretrain_model.checkpoint_items)
    ckpt.save(config.init_checkpoint)
    task.initialize(model)

  def _export_bert_tfhub(self):
    bert_config = configs.BertConfig(
        vocab_size=30522,
        hidden_size=16,
        intermediate_size=32,
        max_position_embeddings=128,
        num_attention_heads=2,
        num_hidden_layers=4)
    encoder = export_tfhub_lib.get_bert_encoder(bert_config)
    model_checkpoint_dir = os.path.join(self.get_temp_dir(), "checkpoint")
    checkpoint = tf.train.Checkpoint(encoder=encoder)
    checkpoint.save(os.path.join(model_checkpoint_dir, "test"))
    model_checkpoint_path = tf.train.latest_checkpoint(model_checkpoint_dir)

    vocab_file = os.path.join(self.get_temp_dir(), "uncased_vocab.txt")
    with tf.io.gfile.GFile(vocab_file, "w") as f:
      f.write("dummy content")

    export_path = os.path.join(self.get_temp_dir(), "hub")
    export_tfhub_lib.export_model(
        export_path,
        bert_config=bert_config,
        encoder_config=None,
        model_checkpoint_path=model_checkpoint_path,
        vocab_file=vocab_file,
        do_lower_case=True,
        with_mlm=False)
    return export_path

  def test_task_with_hub(self):
    hub_module_url = self._export_bert_tfhub()
    config = dual_encoder.DualEncoderConfig(
        hub_module_url=hub_module_url,
        model=self.get_model_config(),
        train_data=self._train_data_config)
    self._run_task(config)


if __name__ == "__main__":
  tf.test.main()
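Putting the four files together, a minimal sketch of configuring the task for a real (non-dummy) run. The TFRecord and vocab paths and the encoder size are placeholder assumptions; the class names, fields, and step pattern come from the code above and mirror dual_encoder_test.py.

import tensorflow as tf

from official.nlp.configs import encoders
from official.nlp.data import dual_encoder_dataloader
from official.nlp.tasks import dual_encoder

config = dual_encoder.DualEncoderConfig(
    model=dual_encoder.ModelConfig(
        max_sequence_length=64,
        logit_scale=10.0,
        logit_margin=0.3,
        bidirectional=True,
        encoder=encoders.EncoderConfig(
            bert=encoders.BertEncoderConfig(vocab_size=30522, num_layers=4))),
    train_data=dual_encoder_dataloader.DualEncoderDataConfig(
        input_path='/tmp/dual_encoder/train.tf_record',  # placeholder path
        vocab_file='/tmp/dual_encoder/vocab.txt',        # placeholder path
        lower_case=True,
        seq_length=64,
        global_batch_size=32))

task = dual_encoder.DualEncoderTask(config)
model = task.build_model()
metrics = task.build_metrics()
dataset = task.build_inputs(config.train_data)

# One training step, following the pattern used in dual_encoder_test.py.
optimizer = tf.keras.optimizers.SGD(learning_rate=0.1)
task.train_step(next(iter(dataset)), model, optimizer, metrics=metrics)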