ModelZoo / ResNet50_tensorflow · Commits

Commit ca552843 (unverified)
Authored Sep 16, 2021 by Srihari Humbarwadi; committed by GitHub on Sep 16, 2021
Parents: 7e2f7a35, 6b90e134

    Merge branch 'panoptic-segmentation' into panoptic-segmentation

Changes: 283
Showing 20 changed files with 1857 additions and 40 deletions (+1857, -40)
official/nlp/projects/teams/teams.py                                 +104   -0
official/nlp/projects/teams/teams_experiments.py                      +64   -0
official/nlp/projects/teams/teams_experiments_test.py                 +38   -0
official/nlp/projects/teams/teams_pretrainer.py                      +460   -0
official/nlp/projects/teams/teams_pretrainer_test.py                 +188   -0
official/nlp/projects/teams/teams_task.py                            +250   -0
official/nlp/projects/teams/teams_task_test.py                        +56   -0
official/nlp/projects/triviaqa/inputs.py                               +5   -5
official/nlp/serving/export_savedmodel_util.py                        +40   -1
official/nlp/serving/serving_modules.py                               +24   -6
official/nlp/tasks/sentence_prediction.py                             +14   -3
official/nlp/tasks/translation.py                                      +1   -1
official/nlp/tasks/translation_test.py                                 +8   -4
official/nlp/train.py                                                 +32  -20
official/pip_package/setup.py                                          +1   -0
official/projects/basnet/README.md                                    +35   -0
official/projects/basnet/configs/basnet.py                           +156   -0
official/projects/basnet/configs/basnet_test.py                       +42   -0
official/projects/basnet/configs/experiments/basnet_dut_gpu.yaml      +10   -0
official/projects/basnet/evaluation/metrics.py                       +329   -0
official/nlp/projects/teams/teams.py (new file, mode 100644)

# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""TEAMS model configurations and instantiation methods."""
import dataclasses

import gin
import tensorflow as tf

from official.modeling import tf_utils
from official.modeling.hyperparams import base_config
from official.nlp.configs import encoders
from official.nlp.modeling import layers
from official.nlp.modeling import networks


@dataclasses.dataclass
class TeamsPretrainerConfig(base_config.Config):
  """Teams pretrainer configuration."""
  # Candidate size for multi-word selection task, including the correct word.
  candidate_size: int = 5
  # Weight for the generator masked language model task.
  generator_loss_weight: float = 1.0
  # Weight for the replaced token detection task.
  discriminator_rtd_loss_weight: float = 5.0
  # Weight for the multi-word selection task.
  discriminator_mws_loss_weight: float = 2.0
  # Whether to share the embedding network between generator and discriminator.
  tie_embeddings: bool = True
  # Number of bottom layers shared between generator and discriminator.
  # Non-positive value implies no sharing.
  num_shared_generator_hidden_layers: int = 3
  # Number of bottom layers shared between different discriminator tasks.
  num_discriminator_task_agnostic_layers: int = 11
  generator: encoders.BertEncoderConfig = encoders.BertEncoderConfig()
  discriminator: encoders.BertEncoderConfig = encoders.BertEncoderConfig()
  # Used for compatibility with continuous finetuning where common BERT config
  # is used.
  encoder: encoders.EncoderConfig = encoders.EncoderConfig()


@gin.configurable
def get_encoder(bert_config,
                embedding_network=None,
                hidden_layers=layers.Transformer):
  """Gets an `EncoderScaffold` object.

  Args:
    bert_config: A 'modeling.BertConfig'.
    embedding_network: Embedding network instance.
    hidden_layers: List of hidden layer instances.

  Returns:
    An encoder object.
  """
  # embedding_size is required for PackedSequenceEmbedding.
  if bert_config.embedding_size is None:
    bert_config.embedding_size = bert_config.hidden_size
  embedding_cfg = dict(
      vocab_size=bert_config.vocab_size,
      type_vocab_size=bert_config.type_vocab_size,
      hidden_size=bert_config.hidden_size,
      embedding_width=bert_config.embedding_size,
      max_seq_length=bert_config.max_position_embeddings,
      initializer=tf.keras.initializers.TruncatedNormal(
          stddev=bert_config.initializer_range),
      dropout_rate=bert_config.dropout_rate,
  )
  hidden_cfg = dict(
      num_attention_heads=bert_config.num_attention_heads,
      intermediate_size=bert_config.intermediate_size,
      intermediate_activation=tf_utils.get_activation(
          bert_config.hidden_activation),
      dropout_rate=bert_config.dropout_rate,
      attention_dropout_rate=bert_config.attention_dropout_rate,
      kernel_initializer=tf.keras.initializers.TruncatedNormal(
          stddev=bert_config.initializer_range),
  )
  if embedding_network is None:
    embedding_network = networks.PackedSequenceEmbedding(**embedding_cfg)
  kwargs = dict(
      embedding_cfg=embedding_cfg,
      embedding_cls=embedding_network,
      hidden_cls=hidden_layers,
      hidden_cfg=hidden_cfg,
      num_hidden_instances=bert_config.num_layers,
      pooled_output_dim=bert_config.hidden_size,
      pooler_layer_initializer=tf.keras.initializers.TruncatedNormal(
          stddev=bert_config.initializer_range),
      dict_outputs=True)

  # Relies on gin configuration to define the Transformer encoder arguments.
  return networks.encoder_scaffold.EncoderScaffold(**kwargs)
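Not part of this commit, but for orientation: `get_encoder` above is normally driven by an `encoders.BertEncoderConfig` (exactly what `teams_task._build_pretrainer`, added later in this commit, passes in). A minimal sketch, assuming a hypothetical small config and default values for the remaining fields:

# Sketch only: hypothetical config values, not taken from this commit.
from official.nlp.configs import encoders
from official.nlp.projects.teams import teams

small_cfg = encoders.BertEncoderConfig(
    vocab_size=30522, num_layers=4, hidden_size=256, num_attention_heads=4)
discriminator = teams.get_encoder(small_cfg)  # EncoderScaffold with dict outputs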
official/nlp/projects/teams/teams_experiments.py (new file, mode 100644)

# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Lint as: python3
# pylint: disable=g-doc-return-or-yield,line-too-long
"""TEAMS experiments."""
import dataclasses

from official.core import config_definitions as cfg
from official.core import exp_factory
from official.modeling import optimization
from official.nlp.data import pretrain_dataloader
from official.nlp.projects.teams import teams_task


AdamWeightDecay = optimization.AdamWeightDecayConfig
PolynomialLr = optimization.PolynomialLrConfig
PolynomialWarmupConfig = optimization.PolynomialWarmupConfig


@dataclasses.dataclass
class TeamsOptimizationConfig(optimization.OptimizationConfig):
  """TEAMS optimization config."""
  optimizer: optimization.OptimizerConfig = optimization.OptimizerConfig(
      type="adamw",
      adamw=AdamWeightDecay(
          weight_decay_rate=0.01,
          exclude_from_weight_decay=["LayerNorm", "layer_norm", "bias"],
          epsilon=1e-6))
  learning_rate: optimization.LrConfig = optimization.LrConfig(
      type="polynomial",
      polynomial=PolynomialLr(
          initial_learning_rate=1e-4,
          decay_steps=1000000,
          end_learning_rate=0.0))
  warmup: optimization.WarmupConfig = optimization.WarmupConfig(
      type="polynomial", polynomial=PolynomialWarmupConfig(warmup_steps=10000))


@exp_factory.register_config_factory("teams/pretraining")
def teams_pretrain() -> cfg.ExperimentConfig:
  """TEAMS pretraining."""
  config = cfg.ExperimentConfig(
      task=teams_task.TeamsPretrainTaskConfig(
          train_data=pretrain_dataloader.BertPretrainDataConfig(),
          validation_data=pretrain_dataloader.BertPretrainDataConfig(
              is_training=False)),
      trainer=cfg.TrainerConfig(
          optimizer_config=TeamsOptimizationConfig(), train_steps=1000000),
      restrictions=[
          "task.train_data.is_training != None",
          "task.validation_data.is_training != None"
      ])
  return config
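For reference (this mirrors the test file that follows rather than adding new behavior): importing this module runs the `register_config_factory` decorator above, after which the experiment is resolvable by its registered name.

# Sketch: look up the registered TEAMS pretraining experiment by name.
from official.core import exp_factory
from official.nlp.projects.teams import teams_experiments  # pylint: disable=unused-import

config = exp_factory.get_exp_config('teams/pretraining')
print(config.trainer.train_steps)  # 1000000, from the TrainerConfig above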
official/nlp/projects/teams/teams_experiments_test.py (new file, mode 100644)

# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Lint as: python3
"""Tests for teams_experiments."""
from absl.testing import parameterized
import tensorflow as tf

# pylint: disable=unused-import
from official.common import registry_imports
# pylint: enable=unused-import
from official.core import config_definitions as cfg
from official.core import exp_factory


class TeamsExperimentsTest(tf.test.TestCase, parameterized.TestCase):

  @parameterized.parameters(('teams/pretraining',))
  def test_teams_experiments(self, config_name):
    config = exp_factory.get_exp_config(config_name)
    self.assertIsInstance(config, cfg.ExperimentConfig)
    self.assertIsInstance(config.task.train_data, cfg.DataConfig)


if __name__ == '__main__':
  tf.test.main()
official/nlp/projects/teams/teams_pretrainer.py (new file, mode 100644)

# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Trainer network for TEAMS models."""
# pylint: disable=g-classes-have-attributes
import tensorflow as tf

from official.modeling import tf_utils
from official.nlp.modeling import layers
from official.nlp.modeling import models


class ReplacedTokenDetectionHead(tf.keras.layers.Layer):
  """Replaced token detection discriminator head.

  Arguments:
    encoder_cfg: Encoder config, used to create hidden layers and head.
    num_task_agnostic_layers: Number of task agnostic layers in the
      discriminator.
    output: The output style for this network. Can be either 'logits' or
      'predictions'.
  """

  def __init__(self,
               encoder_cfg,
               num_task_agnostic_layers,
               output='logits',
               name='rtd',
               **kwargs):
    super(ReplacedTokenDetectionHead, self).__init__(name=name, **kwargs)
    self.num_task_agnostic_layers = num_task_agnostic_layers
    self.hidden_size = encoder_cfg['embedding_cfg']['hidden_size']
    self.num_hidden_instances = encoder_cfg['num_hidden_instances']
    self.hidden_cfg = encoder_cfg['hidden_cfg']
    self.activation = self.hidden_cfg['intermediate_activation']
    self.initializer = self.hidden_cfg['kernel_initializer']

    self.hidden_layers = []
    for i in range(self.num_task_agnostic_layers, self.num_hidden_instances):
      self.hidden_layers.append(
          layers.Transformer(
              num_attention_heads=self.hidden_cfg['num_attention_heads'],
              intermediate_size=self.hidden_cfg['intermediate_size'],
              intermediate_activation=self.activation,
              dropout_rate=self.hidden_cfg['dropout_rate'],
              attention_dropout_rate=self.hidden_cfg['attention_dropout_rate'],
              kernel_initializer=self.initializer,
              name='transformer/layer_%d_rtd' % i))
    self.dense = tf.keras.layers.Dense(
        self.hidden_size,
        activation=self.activation,
        kernel_initializer=self.initializer,
        name='transform/rtd_dense')
    self.rtd_head = tf.keras.layers.Dense(
        units=1,
        kernel_initializer=self.initializer,
        name='transform/rtd_head')

    if output not in ('predictions', 'logits'):
      raise ValueError(
          ('Unknown `output` value "%s". `output` can be either "logits" or '
           '"predictions"') % output)
    self._output_type = output

  def call(self, sequence_data, input_mask):
    """Compute inner-products of hidden vectors with sampled element embeddings.

    Args:
      sequence_data: A [batch_size, seq_length, num_hidden] tensor.
      input_mask: A [batch_size, seq_length] binary mask to separate the input
        from the padding.

    Returns:
      A [batch_size, seq_length] tensor.
    """
    attention_mask = layers.SelfAttentionMask()([sequence_data, input_mask])
    data = sequence_data
    for hidden_layer in self.hidden_layers:
      data = hidden_layer([sequence_data, attention_mask])
    rtd_logits = self.rtd_head(self.dense(data))
    return tf.squeeze(rtd_logits, axis=-1)


class MultiWordSelectionHead(tf.keras.layers.Layer):
  """Multi-word selection discriminator head.

  Arguments:
    embedding_table: The embedding table.
    activation: The activation, if any, for the dense layer.
    initializer: The initializer for the dense layer. Defaults to a Glorot
      uniform initializer.
    output: The output style for this network. Can be either 'logits' or
      'predictions'.
  """

  def __init__(self,
               embedding_table,
               activation=None,
               initializer='glorot_uniform',
               output='logits',
               name='mws',
               **kwargs):
    super(MultiWordSelectionHead, self).__init__(name=name, **kwargs)
    self.embedding_table = embedding_table
    self.activation = activation
    self.initializer = tf.keras.initializers.get(initializer)

    self._vocab_size, self.embed_size = self.embedding_table.shape
    self.dense = tf.keras.layers.Dense(
        self.embed_size,
        activation=self.activation,
        kernel_initializer=self.initializer,
        name='transform/mws_dense')
    self.layer_norm = tf.keras.layers.LayerNormalization(
        axis=-1, epsilon=1e-12, name='transform/mws_layernorm')

    if output not in ('predictions', 'logits'):
      raise ValueError(
          ('Unknown `output` value "%s". `output` can be either "logits" or '
           '"predictions"') % output)
    self._output_type = output

  def call(self, sequence_data, masked_positions, candidate_sets):
    """Compute inner-products of hidden vectors with sampled element embeddings.

    Args:
      sequence_data: A [batch_size, seq_length, num_hidden] tensor.
      masked_positions: A [batch_size, num_prediction] tensor.
      candidate_sets: A [batch_size, num_prediction, k] tensor.

    Returns:
      A [batch_size, num_prediction, k] tensor.
    """
    # Gets shapes for later usage.
    candidate_set_shape = tf_utils.get_shape_list(candidate_sets)
    num_prediction = candidate_set_shape[1]

    # Gathers hidden vectors -> (batch_size, num_prediction, 1, embed_size).
    masked_lm_input = self._gather_indexes(sequence_data, masked_positions)
    lm_data = self.dense(masked_lm_input)
    lm_data = self.layer_norm(lm_data)
    lm_data = tf.expand_dims(
        tf.reshape(lm_data, [-1, num_prediction, self.embed_size]), 2)

    # Gathers embeddings -> (batch_size, num_prediction, embed_size, k).
    flat_candidate_sets = tf.reshape(candidate_sets, [-1])
    candidate_embeddings = tf.gather(self.embedding_table, flat_candidate_sets)
    candidate_embeddings = tf.reshape(
        candidate_embeddings,
        tf.concat([tf.shape(candidate_sets), [self.embed_size]], axis=0))
    candidate_embeddings.set_shape(
        candidate_sets.shape.as_list() + [self.embed_size])
    candidate_embeddings = tf.transpose(candidate_embeddings, [0, 1, 3, 2])

    # Matrix multiplication + squeeze -> (batch_size, num_prediction, k).
    logits = tf.matmul(lm_data, candidate_embeddings)
    logits = tf.squeeze(logits, 2)
    if self._output_type == 'logits':
      return logits
    return tf.nn.log_softmax(logits)

  def _gather_indexes(self, sequence_tensor, positions):
    """Gathers the vectors at the specific positions.

    Args:
      sequence_tensor: Sequence output of shape
        (`batch_size`, `seq_length`, `num_hidden`) where `num_hidden` is
        number of hidden units.
      positions: Positions ids of tokens in batched sequences.

    Returns:
      Sequence tensor of shape (batch_size * num_predictions, num_hidden).
    """
    sequence_shape = tf_utils.get_shape_list(
        sequence_tensor, name='sequence_output_tensor')
    batch_size, seq_length, width = sequence_shape

    flat_offsets = tf.reshape(
        tf.range(0, batch_size, dtype=tf.int32) * seq_length, [-1, 1])
    flat_positions = tf.reshape(positions + flat_offsets, [-1])
    flat_sequence_tensor = tf.reshape(sequence_tensor,
                                      [batch_size * seq_length, width])
    output_tensor = tf.gather(flat_sequence_tensor, flat_positions)

    return output_tensor


@tf.keras.utils.register_keras_serializable(package='Text')
class TeamsPretrainer(tf.keras.Model):
  """TEAMS network training model.

  This is an implementation of the network structure described in "Training
  ELECTRA Augmented with Multi-word Selection"
  (https://arxiv.org/abs/2106.00139).

  The TeamsPretrainer allows a user to pass in two transformer encoders, one
  for the generator, the other for the discriminator (multi-word selection).
  The pretrainer then instantiates the masked language model (at the generator
  side) and classification networks (including both the multi-word selection
  head and the replaced token detection head) that are used to create the
  training objectives.

  *Note* that the model is constructed by Keras Subclass API, where layers are
  defined inside `__init__` and `call()` implements the computation.

  Args:
    generator_network: A transformer encoder for the generator; this network
      should output a sequence output.
    discriminator_mws_network: A transformer encoder for the multi-word
      selection discriminator; this network should output a sequence output.
    num_discriminator_task_agnostic_layers: Number of layers shared between
      multi-word selection and random token detection discriminators.
    vocab_size: Size of generator output vocabulary.
    candidate_size: Candidate size for multi-word selection task, including
      the correct word.
    mlm_activation: The activation (if any) to use in the masked LM and
      classification networks. If None, no activation will be used.
    mlm_initializer: The initializer (if any) to use in the masked LM and
      classification networks. Defaults to a Glorot uniform initializer.
    output_type: The output style for this network. Can be either `logits` or
      `predictions`.
  """

  def __init__(self,
               generator_network,
               discriminator_mws_network,
               num_discriminator_task_agnostic_layers,
               vocab_size,
               candidate_size=5,
               mlm_activation=None,
               mlm_initializer='glorot_uniform',
               output_type='logits',
               **kwargs):
    super().__init__()
    self._config = {
        'generator_network': generator_network,
        'discriminator_mws_network': discriminator_mws_network,
        'num_discriminator_task_agnostic_layers':
            num_discriminator_task_agnostic_layers,
        'vocab_size': vocab_size,
        'candidate_size': candidate_size,
        'mlm_activation': mlm_activation,
        'mlm_initializer': mlm_initializer,
        'output_type': output_type,
    }
    for k, v in kwargs.items():
      self._config[k] = v

    self.generator_network = generator_network
    self.discriminator_mws_network = discriminator_mws_network
    self.vocab_size = vocab_size
    self.candidate_size = candidate_size
    self.mlm_activation = mlm_activation
    self.mlm_initializer = mlm_initializer
    self.output_type = output_type
    self.embedding_table = (
        self.discriminator_mws_network.embedding_network.get_embedding_table())
    self.masked_lm = layers.MaskedLM(
        embedding_table=self.embedding_table,
        activation=mlm_activation,
        initializer=mlm_initializer,
        output=output_type,
        name='generator_masked_lm')
    discriminator_cfg = self.discriminator_mws_network.get_config()
    self.num_task_agnostic_layers = num_discriminator_task_agnostic_layers
    self.discriminator_rtd_head = ReplacedTokenDetectionHead(
        encoder_cfg=discriminator_cfg,
        num_task_agnostic_layers=self.num_task_agnostic_layers,
        output=output_type,
        name='discriminator_rtd')
    hidden_cfg = discriminator_cfg['hidden_cfg']
    self.discriminator_mws_head = MultiWordSelectionHead(
        embedding_table=self.embedding_table,
        activation=hidden_cfg['intermediate_activation'],
        initializer=hidden_cfg['kernel_initializer'],
        output=output_type,
        name='discriminator_mws')

  def call(self, inputs):
    """TEAMS forward pass.

    Args:
      inputs: A dict of all inputs, same as the standard BERT model.

    Returns:
      outputs: A dict of pretrainer model outputs, including
        (1) lm_outputs: A `[batch_size, num_token_predictions, vocab_size]`
        tensor indicating logits on masked positions.
        (2) disc_rtd_logits: A `[batch_size, sequence_length]` tensor
        indicating logits for discriminator replaced token detection task.
        (3) disc_rtd_label: A `[batch_size, sequence_length]` tensor indicating
        target labels for discriminator replaced token detection task.
        (4) disc_mws_logits: A `[batch_size, num_token_predictions,
        candidate_size]` tensor indicating logits for discriminator multi-word
        selection task.
        (5) disc_mws_label: A `[batch_size, num_token_predictions]` tensor
        indicating target labels for discriminator multi-word selection task.
    """
    input_word_ids = inputs['input_word_ids']
    input_mask = inputs['input_mask']
    input_type_ids = inputs['input_type_ids']
    masked_lm_positions = inputs['masked_lm_positions']

    # Runs generator.
    sequence_output = self.generator_network(
        [input_word_ids, input_mask, input_type_ids])['sequence_output']
    lm_outputs = self.masked_lm(sequence_output, masked_lm_positions)

    # Samples tokens from generator.
    fake_data = self._get_fake_data(inputs, lm_outputs)

    # Runs discriminator.
    disc_input = fake_data['inputs']
    disc_rtd_label = fake_data['is_fake_tokens']
    disc_mws_candidates = fake_data['candidate_set']
    mws_sequence_outputs = self.discriminator_mws_network([
        disc_input['input_word_ids'], disc_input['input_mask'],
        disc_input['input_type_ids']
    ])['encoder_outputs']

    # Applies replaced token detection with input selected based on
    # self.num_discriminator_task_agnostic_layers.
    disc_rtd_logits = self.discriminator_rtd_head(
        mws_sequence_outputs[self.num_task_agnostic_layers - 1], input_mask)

    # Applies multi-word selection.
    disc_mws_logits = self.discriminator_mws_head(mws_sequence_outputs[-1],
                                                  masked_lm_positions,
                                                  disc_mws_candidates)
    disc_mws_label = tf.zeros_like(masked_lm_positions, dtype=tf.int32)

    outputs = {
        'lm_outputs': lm_outputs,
        'disc_rtd_logits': disc_rtd_logits,
        'disc_rtd_label': disc_rtd_label,
        'disc_mws_logits': disc_mws_logits,
        'disc_mws_label': disc_mws_label,
    }
    return outputs

  def _get_fake_data(self, inputs, mlm_logits):
    """Generate corrupted data for discriminator.

    Note it is possible for a sampled token to be the same as the correct one.

    Args:
      inputs: A dict of all inputs, same as the input of the `call()` function.
      mlm_logits: The generator's output logits.

    Returns:
      A dict of generated fake data.
    """
    inputs = models.electra_pretrainer.unmask(inputs, duplicate=True)

    # Samples replaced token.
    sampled_tokens = tf.stop_gradient(
        models.electra_pretrainer.sample_from_softmax(
            mlm_logits, disallow=None))
    sampled_tokids = tf.argmax(sampled_tokens, axis=-1, output_type=tf.int32)

    # Prepares input and label for replaced token detection task.
    updated_input_ids, masked = models.electra_pretrainer.scatter_update(
        inputs['input_word_ids'], sampled_tokids,
        inputs['masked_lm_positions'])
    rtd_labels = masked * (1 - tf.cast(
        tf.equal(updated_input_ids, inputs['input_word_ids']), tf.int32))
    updated_inputs = models.electra_pretrainer.get_updated_inputs(
        inputs, duplicate=True, input_word_ids=updated_input_ids)

    # Samples (candidate_size-1) negatives and concat with true tokens.
    disallow = tf.one_hot(
        inputs['masked_lm_ids'], depth=self.vocab_size, dtype=tf.float32)
    sampled_candidates = tf.stop_gradient(
        sample_k_from_softmax(
            mlm_logits, k=self.candidate_size - 1, disallow=disallow))
    true_token_id = tf.expand_dims(inputs['masked_lm_ids'], -1)
    candidate_set = tf.concat([true_token_id, sampled_candidates], -1)

    return {
        'inputs': updated_inputs,
        'is_fake_tokens': rtd_labels,
        'sampled_tokens': sampled_tokens,
        'candidate_set': candidate_set
    }

  @property
  def checkpoint_items(self):
    """Returns a dictionary of items to be additionally checkpointed."""
    items = dict(encoder=self.discriminator_mws_network)
    return items

  def get_config(self):
    return self._config

  @classmethod
  def from_config(cls, config, custom_objects=None):
    return cls(**config)


def sample_k_from_softmax(logits, k, disallow=None, use_topk=False):
  """Implement softmax sampling using the gumbel softmax trick to select k items.

  Args:
    logits: A [batch_size, num_token_predictions, vocab_size] tensor indicating
      the generator output logits for each masked position.
    k: Number of samples.
    disallow: If `None`, we directly sample tokens from the logits. Otherwise,
      this is a tensor of size [batch_size, num_token_predictions, vocab_size]
      indicating the true word id in each masked position.
    use_topk: Whether to use tf.nn.top_k or the iterative approach, where the
      latter is empirically faster.

  Returns:
    sampled_tokens: A [batch_size, num_token_predictions, k] tensor indicating
      the sampled word ids in each masked position.
  """
  if use_topk:
    if disallow is not None:
      logits -= 10000.0 * disallow
    uniform_noise = tf.random.uniform(
        tf_utils.get_shape_list(logits), minval=0, maxval=1)
    gumbel_noise = -tf.math.log(-tf.math.log(uniform_noise + 1e-9) + 1e-9)
    _, sampled_tokens = tf.nn.top_k(logits + gumbel_noise, k=k, sorted=False)
  else:
    sampled_tokens_list = []
    vocab_size = tf_utils.get_shape_list(logits)[-1]
    if disallow is not None:
      logits -= 10000.0 * disallow

    uniform_noise = tf.random.uniform(
        tf_utils.get_shape_list(logits), minval=0, maxval=1)
    gumbel_noise = -tf.math.log(-tf.math.log(uniform_noise + 1e-9) + 1e-9)
    logits += gumbel_noise
    for _ in range(k):
      token_ids = tf.argmax(logits, -1, output_type=tf.int32)
      sampled_tokens_list.append(token_ids)
      logits -= 10000.0 * tf.one_hot(
          token_ids, depth=vocab_size, dtype=tf.float32)
    sampled_tokens = tf.stack(sampled_tokens_list, -1)
  return sampled_tokens
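A side note on `sample_k_from_softmax` (illustration only, not part of the commit): both branches rely on the Gumbel-max trick, where `argmax(logits + Gumbel noise)` is distributed as `Categorical(softmax(logits))`, so taking the top-k of one perturbed copy of the logits (or repeatedly taking the argmax and masking it out) draws k distinct tokens without replacement. A tiny standalone sketch:

# Sketch of the Gumbel-max trick used above, on a toy 3-token distribution.
import tensorflow as tf

logits = tf.math.log([[0.5, 0.3, 0.2]])
uniform_noise = tf.random.uniform(tf.shape(logits), minval=0, maxval=1)
gumbel_noise = -tf.math.log(-tf.math.log(uniform_noise + 1e-9) + 1e-9)
# Two distinct token ids sampled without replacement from softmax(logits).
_, sampled = tf.nn.top_k(logits + gumbel_noise, k=2, sorted=False)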
official/nlp/projects/teams/teams_pretrainer_test.py (new file, mode 100644)

# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for TEAMS pre trainer network."""
import tensorflow as tf

from tensorflow.python.keras import keras_parameterized  # pylint: disable=g-direct-tensorflow-import
from official.modeling import activations
from official.nlp.modeling.networks import encoder_scaffold
from official.nlp.modeling.networks import packed_sequence_embedding
from official.nlp.projects.teams import teams_pretrainer


# This decorator runs the test in V1, V2-Eager, and V2-Functional mode. It
# guarantees forward compatibility of this code for the V2 switchover.
@keras_parameterized.run_all_keras_modes
class TeamsPretrainerTest(keras_parameterized.TestCase):

  # Build a transformer network to use within the TEAMS trainer.
  def _get_network(self, vocab_size):
    sequence_length = 512
    hidden_size = 50
    embedding_cfg = {
        'vocab_size': vocab_size,
        'type_vocab_size': 1,
        'hidden_size': hidden_size,
        'embedding_width': hidden_size,
        'max_seq_length': sequence_length,
        'initializer': tf.keras.initializers.TruncatedNormal(stddev=0.02),
        'dropout_rate': 0.1,
    }
    embedding_inst = packed_sequence_embedding.PackedSequenceEmbedding(
        **embedding_cfg)
    hidden_cfg = {
        'num_attention_heads': 2,
        'intermediate_size': 3072,
        'intermediate_activation': activations.gelu,
        'dropout_rate': 0.1,
        'attention_dropout_rate': 0.1,
        'kernel_initializer': tf.keras.initializers.TruncatedNormal(stddev=0.02),
    }
    return encoder_scaffold.EncoderScaffold(
        num_hidden_instances=2,
        pooled_output_dim=hidden_size,
        embedding_cfg=embedding_cfg,
        embedding_cls=embedding_inst,
        hidden_cfg=hidden_cfg,
        dict_outputs=True)

  def test_teams_pretrainer(self):
    """Validate that the Keras object can be created."""
    vocab_size = 100
    test_generator_network = self._get_network(vocab_size)
    test_discriminator_network = self._get_network(vocab_size)

    # Create a TEAMS trainer with the created network.
    candidate_size = 3
    teams_trainer_model = teams_pretrainer.TeamsPretrainer(
        generator_network=test_generator_network,
        discriminator_mws_network=test_discriminator_network,
        num_discriminator_task_agnostic_layers=1,
        vocab_size=vocab_size,
        candidate_size=candidate_size)

    # Create a set of 2-dimensional inputs (the first dimension is implicit).
    num_token_predictions = 2
    sequence_length = 128
    word_ids = tf.keras.Input(shape=(sequence_length,), dtype=tf.int32)
    mask = tf.keras.Input(shape=(sequence_length,), dtype=tf.int32)
    type_ids = tf.keras.Input(shape=(sequence_length,), dtype=tf.int32)
    lm_positions = tf.keras.Input(
        shape=(num_token_predictions,), dtype=tf.int32)
    lm_ids = tf.keras.Input(shape=(num_token_predictions,), dtype=tf.int32)
    inputs = {
        'input_word_ids': word_ids,
        'input_mask': mask,
        'input_type_ids': type_ids,
        'masked_lm_positions': lm_positions,
        'masked_lm_ids': lm_ids
    }

    # Invoke the trainer model on the inputs. This causes the layer to be built.
    outputs = teams_trainer_model(inputs)
    lm_outs = outputs['lm_outputs']
    disc_rtd_logits = outputs['disc_rtd_logits']
    disc_rtd_label = outputs['disc_rtd_label']
    disc_mws_logits = outputs['disc_mws_logits']
    disc_mws_label = outputs['disc_mws_label']

    # Validate that the outputs are of the expected shape.
    expected_lm_shape = [None, num_token_predictions, vocab_size]
    expected_disc_rtd_logits_shape = [None, sequence_length]
    expected_disc_rtd_label_shape = [None, sequence_length]
    expected_disc_disc_mws_logits_shape = [
        None, num_token_predictions, candidate_size
    ]
    expected_disc_disc_mws_label_shape = [None, num_token_predictions]
    self.assertAllEqual(expected_lm_shape, lm_outs.shape.as_list())
    self.assertAllEqual(expected_disc_rtd_logits_shape,
                        disc_rtd_logits.shape.as_list())
    self.assertAllEqual(expected_disc_rtd_label_shape,
                        disc_rtd_label.shape.as_list())
    self.assertAllEqual(expected_disc_disc_mws_logits_shape,
                        disc_mws_logits.shape.as_list())
    self.assertAllEqual(expected_disc_disc_mws_label_shape,
                        disc_mws_label.shape.as_list())

  def test_teams_trainer_tensor_call(self):
    """Validate that the Keras object can be invoked."""
    vocab_size = 100
    test_generator_network = self._get_network(vocab_size)
    test_discriminator_network = self._get_network(vocab_size)

    # Create a TEAMS trainer with the created network.
    teams_trainer_model = teams_pretrainer.TeamsPretrainer(
        generator_network=test_generator_network,
        discriminator_mws_network=test_discriminator_network,
        num_discriminator_task_agnostic_layers=2,
        vocab_size=vocab_size,
        candidate_size=2)

    # Create a set of 2-dimensional data tensors to feed into the model.
    word_ids = tf.constant([[1, 1, 1], [2, 2, 2]], dtype=tf.int32)
    mask = tf.constant([[1, 1, 1], [1, 0, 0]], dtype=tf.int32)
    type_ids = tf.constant([[1, 1, 1], [2, 2, 2]], dtype=tf.int32)
    lm_positions = tf.constant([[0, 1], [0, 2]], dtype=tf.int32)
    lm_ids = tf.constant([[10, 20], [20, 30]], dtype=tf.int32)
    inputs = {
        'input_word_ids': word_ids,
        'input_mask': mask,
        'input_type_ids': type_ids,
        'masked_lm_positions': lm_positions,
        'masked_lm_ids': lm_ids
    }

    # Invoke the trainer model on the tensors. In Eager mode, this does the
    # actual calculation. (We can't validate the outputs, since the network is
    # too complex: this simply ensures we're not hitting runtime errors.)
    _ = teams_trainer_model(inputs)

  def test_serialize_deserialize(self):
    """Validate that the TEAMS trainer can be serialized and deserialized."""
    vocab_size = 100
    test_generator_network = self._get_network(vocab_size)
    test_discriminator_network = self._get_network(vocab_size)

    # Create a TEAMS trainer with the created network. (Note that all the args
    # are different, so we can catch any serialization mismatches.)
    teams_trainer_model = teams_pretrainer.TeamsPretrainer(
        generator_network=test_generator_network,
        discriminator_mws_network=test_discriminator_network,
        num_discriminator_task_agnostic_layers=2,
        vocab_size=vocab_size,
        candidate_size=2)

    # Create another TEAMS trainer via serialization and deserialization.
    config = teams_trainer_model.get_config()
    new_teams_trainer_model = teams_pretrainer.TeamsPretrainer.from_config(
        config)

    # Validate that the config can be forced to JSON.
    _ = new_teams_trainer_model.to_json()

    # If the serialization was successful, the new config should match the old.
    self.assertAllEqual(teams_trainer_model.get_config(),
                        new_teams_trainer_model.get_config())


if __name__ == '__main__':
  tf.test.main()
official/nlp/projects/teams/teams_task.py (new file, mode 100644)

# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""TEAMS pretraining task (Joint Masked LM, Replaced Token Detection and Multi-Word Selection)."""
import dataclasses
import tensorflow as tf

from official.core import base_task
from official.core import config_definitions as cfg
from official.core import task_factory
from official.modeling import tf_utils
from official.nlp.data import pretrain_dataloader
from official.nlp.modeling import layers
from official.nlp.projects.teams import teams
from official.nlp.projects.teams import teams_pretrainer


@dataclasses.dataclass
class TeamsPretrainTaskConfig(cfg.TaskConfig):
  """The model config."""
  model: teams.TeamsPretrainerConfig = teams.TeamsPretrainerConfig()
  train_data: cfg.DataConfig = cfg.DataConfig()
  validation_data: cfg.DataConfig = cfg.DataConfig()


def _get_generator_hidden_layers(discriminator_network, num_hidden_layers,
                                 num_shared_layers):
  if num_shared_layers <= 0:
    num_shared_layers = 0
    hidden_layers = []
  else:
    hidden_layers = discriminator_network.hidden_layers[:num_shared_layers]
  for _ in range(num_shared_layers, num_hidden_layers):
    hidden_layers.append(layers.Transformer)
  return hidden_layers


def _build_pretrainer(
    config: teams.TeamsPretrainerConfig) -> teams_pretrainer.TeamsPretrainer:
  """Instantiates TeamsPretrainer from the config."""
  generator_encoder_cfg = config.generator
  discriminator_encoder_cfg = config.discriminator
  discriminator_network = teams.get_encoder(discriminator_encoder_cfg)
  # Copy discriminator's embeddings to generator for easier model serialization.
  hidden_layers = _get_generator_hidden_layers(
      discriminator_network, generator_encoder_cfg.num_layers,
      config.num_shared_generator_hidden_layers)
  if config.tie_embeddings:
    generator_network = teams.get_encoder(
        generator_encoder_cfg,
        embedding_network=discriminator_network.embedding_network,
        hidden_layers=hidden_layers)
  else:
    generator_network = teams.get_encoder(
        generator_encoder_cfg, hidden_layers=hidden_layers)

  return teams_pretrainer.TeamsPretrainer(
      generator_network=generator_network,
      discriminator_mws_network=discriminator_network,
      num_discriminator_task_agnostic_layers=config
      .num_discriminator_task_agnostic_layers,
      vocab_size=generator_encoder_cfg.vocab_size,
      candidate_size=config.candidate_size,
      mlm_activation=tf_utils.get_activation(
          generator_encoder_cfg.hidden_activation),
      mlm_initializer=tf.keras.initializers.TruncatedNormal(
          stddev=generator_encoder_cfg.initializer_range))


@task_factory.register_task_cls(TeamsPretrainTaskConfig)
class TeamsPretrainTask(base_task.Task):
  """TEAMS Pretrain Task (Masked LM + RTD + MWS)."""

  def build_model(self):
    return _build_pretrainer(self.task_config.model)

  def build_losses(self,
                   labels,
                   model_outputs,
                   metrics,
                   aux_losses=None) -> tf.Tensor:
    with tf.name_scope('TeamsPretrainTask/losses'):
      metrics = dict([(metric.name, metric) for metric in metrics])

      # Generator MLM loss.
      lm_prediction_losses = tf.keras.losses.sparse_categorical_crossentropy(
          labels['masked_lm_ids'],
          tf.cast(model_outputs['lm_outputs'], tf.float32),
          from_logits=True)
      lm_label_weights = labels['masked_lm_weights']
      lm_numerator_loss = tf.reduce_sum(lm_prediction_losses *
                                        lm_label_weights)
      lm_denominator_loss = tf.reduce_sum(lm_label_weights)
      mlm_loss = tf.math.divide_no_nan(lm_numerator_loss, lm_denominator_loss)
      metrics['masked_lm_loss'].update_state(mlm_loss)
      weight = self.task_config.model.generator_loss_weight
      total_loss = weight * mlm_loss

      # Discriminator RTD loss.
      rtd_logits = model_outputs['disc_rtd_logits']
      rtd_labels = tf.cast(model_outputs['disc_rtd_label'], tf.float32)
      input_mask = tf.cast(labels['input_mask'], tf.float32)
      rtd_ind_loss = tf.nn.sigmoid_cross_entropy_with_logits(
          logits=rtd_logits, labels=rtd_labels)
      rtd_numerator = tf.reduce_sum(input_mask * rtd_ind_loss)
      rtd_denominator = tf.reduce_sum(input_mask)
      rtd_loss = tf.math.divide_no_nan(rtd_numerator, rtd_denominator)
      metrics['replaced_token_detection_loss'].update_state(rtd_loss)
      weight = self.task_config.model.discriminator_rtd_loss_weight
      total_loss = total_loss + weight * rtd_loss

      # Discriminator MWS loss.
      mws_logits = model_outputs['disc_mws_logits']
      mws_labels = model_outputs['disc_mws_label']
      mws_loss = tf.keras.losses.sparse_categorical_crossentropy(
          mws_labels, mws_logits, from_logits=True)
      mws_numerator_loss = tf.reduce_sum(mws_loss * lm_label_weights)
      mws_denominator_loss = tf.reduce_sum(lm_label_weights)
      mws_loss = tf.math.divide_no_nan(mws_numerator_loss,
                                       mws_denominator_loss)
      metrics['multiword_selection_loss'].update_state(mws_loss)
      weight = self.task_config.model.discriminator_mws_loss_weight
      total_loss = total_loss + weight * mws_loss

      if aux_losses:
        total_loss += tf.add_n(aux_losses)

      metrics['total_loss'].update_state(total_loss)
      return total_loss

  def build_inputs(self, params, input_context=None):
    """Returns tf.data.Dataset for pretraining."""
    if params.input_path == 'dummy':

      def dummy_data(_):
        dummy_ids = tf.zeros((1, params.seq_length), dtype=tf.int32)
        dummy_lm = tf.zeros((1, params.max_predictions_per_seq),
                            dtype=tf.int32)
        return dict(
            input_word_ids=dummy_ids,
            input_mask=dummy_ids,
            input_type_ids=dummy_ids,
            masked_lm_positions=dummy_lm,
            masked_lm_ids=dummy_lm,
            masked_lm_weights=tf.cast(dummy_lm, dtype=tf.float32))

      dataset = tf.data.Dataset.range(1)
      dataset = dataset.repeat()
      dataset = dataset.map(
          dummy_data, num_parallel_calls=tf.data.experimental.AUTOTUNE)
      return dataset

    return pretrain_dataloader.BertPretrainDataLoader(params).load(
        input_context)

  def build_metrics(self, training=None):
    del training
    metrics = [
        tf.keras.metrics.SparseCategoricalAccuracy(name='masked_lm_accuracy'),
        tf.keras.metrics.Mean(name='masked_lm_loss'),
        tf.keras.metrics.SparseCategoricalAccuracy(
            name='replaced_token_detection_accuracy'),
        tf.keras.metrics.Mean(name='replaced_token_detection_loss'),
        tf.keras.metrics.SparseCategoricalAccuracy(
            name='multiword_selection_accuracy'),
        tf.keras.metrics.Mean(name='multiword_selection_loss'),
        tf.keras.metrics.Mean(name='total_loss'),
    ]
    return metrics

  def process_metrics(self, metrics, labels, model_outputs):
    with tf.name_scope('TeamsPretrainTask/process_metrics'):
      metrics = dict([(metric.name, metric) for metric in metrics])
      if 'masked_lm_accuracy' in metrics:
        metrics['masked_lm_accuracy'].update_state(labels['masked_lm_ids'],
                                                   model_outputs['lm_outputs'],
                                                   labels['masked_lm_weights'])
      if 'replaced_token_detection_accuracy' in metrics:
        rtd_logits_expanded = tf.expand_dims(model_outputs['disc_rtd_logits'],
                                             -1)
        rtd_full_logits = tf.concat(
            [-1.0 * rtd_logits_expanded, rtd_logits_expanded], -1)
        metrics['replaced_token_detection_accuracy'].update_state(
            model_outputs['disc_rtd_label'], rtd_full_logits,
            labels['input_mask'])
      if 'multiword_selection_accuracy' in metrics:
        metrics['multiword_selection_accuracy'].update_state(
            model_outputs['disc_mws_label'], model_outputs['disc_mws_logits'],
            labels['masked_lm_weights'])

  def train_step(self, inputs, model: tf.keras.Model,
                 optimizer: tf.keras.optimizers.Optimizer, metrics):
    """Does forward and backward.

    Args:
      inputs: a dictionary of input tensors.
      model: the model, forward pass definition.
      optimizer: the optimizer for this training step.
      metrics: a nested structure of metrics objects.

    Returns:
      A dictionary of logs.
    """
    with tf.GradientTape() as tape:
      outputs = model(inputs, training=True)
      # Computes per-replica loss.
      loss = self.build_losses(
          labels=inputs,
          model_outputs=outputs,
          metrics=metrics,
          aux_losses=model.losses)
      # Scales loss as the default gradients allreduce performs sum inside the
      # optimizer.
      scaled_loss = loss / tf.distribute.get_strategy().num_replicas_in_sync
    tvars = model.trainable_variables
    grads = tape.gradient(scaled_loss, tvars)
    optimizer.apply_gradients(list(zip(grads, tvars)))
    self.process_metrics(metrics, inputs, outputs)
    return {self.loss: loss}

  def validation_step(self, inputs, model: tf.keras.Model, metrics):
    """Validation step.

    Args:
      inputs: a dictionary of input tensors.
      model: the keras.Model.
      metrics: a nested structure of metrics objects.

    Returns:
      A dictionary of logs.
    """
    outputs = model(inputs, training=False)
    loss = self.build_losses(
        labels=inputs,
        model_outputs=outputs,
        metrics=metrics,
        aux_losses=model.losses)
    self.process_metrics(metrics, inputs, outputs)
    return {self.loss: loss}
official/nlp/projects/teams/teams_task_test.py (new file, mode 100644)

# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for teams_task."""
from absl.testing import parameterized
import tensorflow as tf

from official.nlp.configs import encoders
from official.nlp.data import pretrain_dataloader
from official.nlp.projects.teams import teams
from official.nlp.projects.teams import teams_task


class TeamsPretrainTaskTest(tf.test.TestCase, parameterized.TestCase):

  @parameterized.parameters((1, 1), (0, 1), (0, 0), (1, 0))
  def test_task(self, num_shared_hidden_layers, num_task_agnostic_layers):
    config = teams_task.TeamsPretrainTaskConfig(
        model=teams.TeamsPretrainerConfig(
            generator=encoders.BertEncoderConfig(
                vocab_size=30522, num_layers=2),
            discriminator=encoders.BertEncoderConfig(
                vocab_size=30522, num_layers=2),
            num_shared_generator_hidden_layers=num_shared_hidden_layers,
            num_discriminator_task_agnostic_layers=num_task_agnostic_layers,
        ),
        train_data=pretrain_dataloader.BertPretrainDataConfig(
            input_path="dummy",
            max_predictions_per_seq=20,
            seq_length=128,
            global_batch_size=1))
    task = teams_task.TeamsPretrainTask(config)
    model = task.build_model()
    metrics = task.build_metrics()
    dataset = task.build_inputs(config.train_data)

    iterator = iter(dataset)
    optimizer = tf.keras.optimizers.SGD(lr=0.1)
    task.train_step(next(iterator), model, optimizer, metrics=metrics)
    task.validation_step(next(iterator), model, metrics=metrics)


if __name__ == "__main__":
  tf.test.main()
official/nlp/projects/triviaqa/inputs.py (modified)

@@ -48,15 +48,15 @@ def _flatten_dims(tensor: tf.Tensor,
   rank = tensor.shape.rank
   if rank is None:
     raise ValueError('Static rank of `tensor` must be known.')
-  if first_dim < 0:
+  if first_dim < 0:  # pytype: disable=unsupported-operands
     first_dim += rank
-  if first_dim < 0 or first_dim >= rank:
+  if first_dim < 0 or first_dim >= rank:  # pytype: disable=unsupported-operands
     raise ValueError('`first_dim` out of bounds for `tensor` rank.')
-  if last_dim < 0:
+  if last_dim < 0:  # pytype: disable=unsupported-operands
     last_dim += rank
-  if last_dim < 0 or last_dim >= rank:
+  if last_dim < 0 or last_dim >= rank:  # pytype: disable=unsupported-operands
     raise ValueError('`last_dim` out of bounds for `tensor` rank.')
-  if first_dim > last_dim:
+  if first_dim > last_dim:  # pytype: disable=unsupported-operands
     raise ValueError('`first_dim` must not be larger than `last_dim`.')
   # Try to calculate static flattened dim size if all input sizes to flatten
   ...
official/nlp/serving/export_savedmodel_util.py (modified)

@@ -13,12 +13,19 @@
 # limitations under the License.
 """Common library to export a SavedModel from the export module."""
+import os
+import time
 from typing import Dict, List, Optional, Text, Union
 
+from absl import logging
 import tensorflow as tf
 
 from official.core import export_base
 
+MAX_DIRECTORY_CREATION_ATTEMPTS = 10
+
 
 def export(export_module: export_base.ExportModule,
            function_keys: Union[List[Text], Dict[Text, Text]],
            export_savedmodel_dir: Text,
@@ -39,7 +46,39 @@ def export(export_module: export_base.ExportModule,
     The savedmodel directory path.
   """
   save_options = tf.saved_model.SaveOptions(function_aliases={
-      "tpu_candidate": export_module.serve,
+      'tpu_candidate': export_module.serve,
   })
   return export_base.export(export_module, function_keys, export_savedmodel_dir,
                             checkpoint_path, timestamped, save_options)
+
+
+def get_timestamped_export_dir(export_dir_base):
+  """Builds a path to a new subdirectory within the base directory.
+
+  Args:
+    export_dir_base: A string containing a directory to write the exported
+      graph and checkpoints.
+
+  Returns:
+    The full path of the new subdirectory (which is not actually created yet).
+
+  Raises:
+    RuntimeError: if repeated attempts fail to obtain a unique timestamped
+      directory name.
+  """
+  attempts = 0
+  while attempts < MAX_DIRECTORY_CREATION_ATTEMPTS:
+    timestamp = int(time.time())
+    result_dir = os.path.join(export_dir_base, str(timestamp))
+    if not tf.io.gfile.exists(result_dir):
+      # Collisions are still possible (though extremely unlikely): this
+      # directory is not actually created yet, but it will be almost
+      # instantly on return from this function.
+      return result_dir
+    time.sleep(1)
+    attempts += 1
+    logging.warning('Directory %s already exists; retrying (attempt %s/%s)',
+                    str(result_dir), attempts,
+                    MAX_DIRECTORY_CREATION_ATTEMPTS)
+  raise RuntimeError('Failed to obtain a unique export directory name after '
+                     f'{MAX_DIRECTORY_CREATION_ATTEMPTS} attempts.')
official/nlp/serving/serving_modules.py (modified)

@@ -80,11 +80,10 @@ class SentencePrediction(export_base.ExportModule):
           lower_case=params.lower_case,
           preprocessing_hub_module_url=params.preprocessing_hub_module_url)
 
-  @tf.function
-  def serve(self,
-            input_word_ids,
-            input_mask=None,
-            input_type_ids=None) -> Dict[str, tf.Tensor]:
+  def _serve_tokenized_input(self,
+                             input_word_ids,
+                             input_mask=None,
+                             input_type_ids=None) -> tf.Tensor:
     if input_type_ids is None:
       # Requires CLS token is the first token of inputs.
       input_type_ids = tf.zeros_like(input_word_ids)
@@ -97,7 +96,26 @@ class SentencePrediction(export_base.ExportModule):
         input_word_ids=input_word_ids,
         input_mask=input_mask,
         input_type_ids=input_type_ids)
-    return dict(outputs=self.inference_step(inputs))
+    return self.inference_step(inputs)
+
+  @tf.function
+  def serve(self,
+            input_word_ids,
+            input_mask=None,
+            input_type_ids=None) -> Dict[str, tf.Tensor]:
+    return dict(
+        outputs=self._serve_tokenized_input(input_word_ids, input_mask,
+                                            input_type_ids))
+
+  @tf.function
+  def serve_probability(self,
+                        input_word_ids,
+                        input_mask=None,
+                        input_type_ids=None) -> Dict[str, tf.Tensor]:
+    return dict(
+        outputs=tf.nn.softmax(
+            self._serve_tokenized_input(input_word_ids, input_mask,
+                                        input_type_ids)))
 
   @tf.function
   def serve_examples(self, inputs) -> Dict[str, tf.Tensor]:
 ...
official/nlp/tasks/sentence_prediction.py (modified)

@@ -13,10 +13,10 @@
 # limitations under the License.
 """Sentence prediction (classification) task."""
+import dataclasses
 from typing import List, Union, Optional
 
 from absl import logging
-import dataclasses
 import numpy as np
 import orbit
 from scipy import stats
@@ -140,14 +140,25 @@ class SentencePredictionTask(base_task.Task):
     del training
     if self.task_config.model.num_classes == 1:
       metrics = [tf.keras.metrics.MeanSquaredError()]
+    elif self.task_config.model.num_classes == 2:
+      metrics = [
+          tf.keras.metrics.SparseCategoricalAccuracy(name='cls_accuracy'),
+          tf.keras.metrics.AUC(name='auc', curve='PR'),
+      ]
     else:
       metrics = [
          tf.keras.metrics.SparseCategoricalAccuracy(name='cls_accuracy'),
       ]
     return metrics
 
   def process_metrics(self, metrics, labels, model_outputs):
     for metric in metrics:
-      metric.update_state(labels[self.label_field], model_outputs)
+      if metric.name == 'auc':
+        # Convert the logit to probability and extract the probability of True..
+        metric.update_state(
+            labels[self.label_field],
+            tf.expand_dims(tf.nn.softmax(model_outputs)[:, 1], axis=1))
+      if metric.name == 'cls_accuracy':
+        metric.update_state(labels[self.label_field], model_outputs)
 
   def process_compiled_metrics(self, compiled_metrics, labels, model_outputs):
 ...
official/nlp/tasks/translation.py (modified)

@@ -13,11 +13,11 @@
 # limitations under the License.
 """Defines the translation task."""
+import dataclasses
 import os
 from typing import Optional
 
 from absl import logging
-import dataclasses
 import sacrebleu
 import tensorflow as tf
 import tensorflow_text as tftxt
 ...
official/nlp/tasks/translation_test.py
View file @ ca552843
...
@@ -85,7 +85,8 @@ class TranslationTaskTest(tf.test.TestCase):
   def test_task(self):
     config = translation.TranslationConfig(
         model=translation.ModelConfig(
-            encoder=translation.EncDecoder(), decoder=translation.EncDecoder()),
+            encoder=translation.EncDecoder(num_layers=1),
+            decoder=translation.EncDecoder(num_layers=1)),
         train_data=wmt_dataloader.WMTDataConfig(
             input_path=self._record_input_path,
             src_lang="en", tgt_lang="reverse_en",
...
@@ -102,7 +103,8 @@ class TranslationTaskTest(tf.test.TestCase):
   def test_no_sentencepiece_path(self):
     config = translation.TranslationConfig(
         model=translation.ModelConfig(
-            encoder=translation.EncDecoder(), decoder=translation.EncDecoder()),
+            encoder=translation.EncDecoder(num_layers=1),
+            decoder=translation.EncDecoder(num_layers=1)),
         train_data=wmt_dataloader.WMTDataConfig(
             input_path=self._record_input_path,
             src_lang="en", tgt_lang="reverse_en",
...
@@ -122,7 +124,8 @@ class TranslationTaskTest(tf.test.TestCase):
         sentencepeice_model_prefix)
     config = translation.TranslationConfig(
         model=translation.ModelConfig(
-            encoder=translation.EncDecoder(), decoder=translation.EncDecoder()),
+            encoder=translation.EncDecoder(num_layers=1),
+            decoder=translation.EncDecoder(num_layers=1)),
         train_data=wmt_dataloader.WMTDataConfig(
             input_path=self._record_input_path,
             src_lang="en", tgt_lang="reverse_en",
...
@@ -137,7 +140,8 @@ class TranslationTaskTest(tf.test.TestCase):
   def test_evaluation(self):
     config = translation.TranslationConfig(
         model=translation.ModelConfig(
-            encoder=translation.EncDecoder(), decoder=translation.EncDecoder(),
+            encoder=translation.EncDecoder(num_layers=1),
+            decoder=translation.EncDecoder(num_layers=1),
             padded_decode=False,
             decode_max_length=64),
         validation_data=wmt_dataloader.WMTDataConfig(
...
official/nlp/train.py
View file @ ca552843
...
@@ -27,9 +27,15 @@ from official.core import task_factory
 from official.core import train_lib
 from official.core import train_utils
 from official.modeling import performance
+from official.nlp import continuous_finetune_lib

 FLAGS = flags.FLAGS

+flags.DEFINE_integer(
+    'pretrain_steps',
+    default=None,
+    help='The number of total training steps for the pretraining job.')
+

 def main(_):
   gin.parse_config_files_and_bindings(FLAGS.gin_file, FLAGS.gin_params)
...
@@ -40,12 +46,18 @@ def main(_):
   # may race against the train job for writing the same file.
   train_utils.serialize_config(params, model_dir)

+  if FLAGS.mode == 'continuous_train_and_eval':
+    continuous_finetune_lib.run_continuous_finetune(
+        FLAGS.mode, params, model_dir, pretrain_steps=FLAGS.pretrain_steps)
+  else:
     # Sets mixed_precision policy. Using 'mixed_float16' or 'mixed_bfloat16'
-    # can have significant impact on model speeds by utilizing float16 in case of
-    # GPUs, and bfloat16 in the case of TPUs. loss_scale takes effect only when
-    # dtype is float16
+    # can have significant impact on model speeds by utilizing float16 in case
+    # of GPUs, and bfloat16 in the case of TPUs. loss_scale takes effect only
+    # when dtype is float16
     if params.runtime.mixed_precision_dtype:
       performance.set_mixed_precision_policy(params.runtime.mixed_precision_dtype)
     distribution_strategy = distribute_utils.get_distribution_strategy(
         distribution_strategy=params.runtime.distribution_strategy,
         all_reduce_alg=params.runtime.all_reduce_alg,
...
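As context for the comment in the hunk above: a mixed-precision policy makes layers built afterwards compute in float16 (GPUs) or bfloat16 (TPUs) while keeping their variables in float32, and loss scaling only matters for float16. The Model Garden helper `performance.set_mixed_precision_policy` effectively configures this global Keras policy; a minimal illustration independent of those helpers:

```python
import tensorflow as tf

# After this call, newly built layers compute in float16 but store
# their variables in float32.
tf.keras.mixed_precision.set_global_policy('mixed_float16')

layer = tf.keras.layers.Dense(8)
print(layer.dtype_policy.compute_dtype)   # float16
print(layer.dtype_policy.variable_dtype)  # float32
```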
official/pip_package/setup.py
View file @ ca552843
...
@@ -81,6 +81,7 @@ setup(
         'official.pip_package*',
         'official.benchmark*',
         'official.colab*',
+        'official.recommendation.ranking.data.preprocessing*',
     ]),
     exclude_package_data={
         '': ['*_test.py',],
...
official/projects/basnet/README.md
0 → 100644
View file @ ca552843

# BASNet: Boundary-Aware Salient Object Detection

This repository is an unofficial implementation of the following paper. Please
see the paper
[BASNet: Boundary-Aware Salient Object Detection](https://openaccess.thecvf.com/content_CVPR_2019/html/Qin_BASNet_Boundary-Aware_Salient_Object_Detection_CVPR_2019_paper.html)
for more details.

## Requirements

[TensorFlow 2.4.0](https://github.com/tensorflow/tensorflow/releases/tag/v2.4.0)
[Python 3.7.9](https://www.python.org/downloads/release/python-379/)

## Train

```shell
$ python3 train.py \
  --experiment=basnet_duts \
  --mode=train \
  --model_dir=$MODEL_DIR \
  --config_file=./configs/experiments/basnet_dut_gpu.yaml
```

## Test

```shell
$ python3 train.py \
  --experiment=basnet_duts \
  --mode=eval \
  --model_dir=$MODEL_DIR \
  --config_file=./configs/experiments/basnet_dut_gpu.yaml \
  --params_override='runtime.num_gpus=1, runtime.distribution_strategy=one_device, task.model.input_size=[256, 256, 3]'
```

## Results

Dataset | maxF<sub>β</sub> | relaxF<sub>β</sub> | MAE
:------ | :--------------- | :----------------- | ----:
DUTS-TE | 0.865            | 0.793               | 0.046
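The same experiment can also be built programmatically. A minimal sketch (assuming the `basnet` configs module is importable so the `basnet_duts` factory defined in `configs/basnet.py` gets registered) that mirrors the eval-time overrides from the Test command above:

```python
from official.core import exp_factory
from official.projects.basnet.configs import basnet  # importing registers 'basnet_duts'

config = exp_factory.get_exp_config('basnet_duts')
# Same overrides as --params_override in the Test command.
config.runtime.num_gpus = 1
config.runtime.distribution_strategy = 'one_device'
config.task.model.input_size = [256, 256, 3]
config.validate()
```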
official/projects/basnet/configs/basnet.py
0 → 100644
View file @ ca552843
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""BASNet configuration definition."""
import dataclasses
import os
from typing import List, Optional, Union

from official.core import exp_factory
from official.modeling import hyperparams
from official.modeling import optimization
from official.modeling.hyperparams import config_definitions as cfg
from official.vision.beta.configs import common


@dataclasses.dataclass
class DataConfig(cfg.DataConfig):
  """Input config for training."""
  output_size: List[int] = dataclasses.field(default_factory=list)
  # If crop_size is specified, image will be resized first to
  # output_size, then crop of size crop_size will be cropped.
  crop_size: List[int] = dataclasses.field(default_factory=list)
  input_path: str = ''
  global_batch_size: int = 0
  is_training: bool = True
  dtype: str = 'float32'
  shuffle_buffer_size: int = 1000
  cycle_length: int = 10
  resize_eval_groundtruth: bool = True
  groundtruth_padded_size: List[int] = dataclasses.field(default_factory=list)
  aug_rand_hflip: bool = True
  file_type: str = 'tfrecord'


@dataclasses.dataclass
class BASNetModel(hyperparams.Config):
  """BASNet model config."""
  input_size: List[int] = dataclasses.field(default_factory=list)
  use_bias: bool = False
  norm_activation: common.NormActivation = common.NormActivation()


@dataclasses.dataclass
class Losses(hyperparams.Config):
  label_smoothing: float = 0.1
  ignore_label: int = 0  # will be treated as background
  l2_weight_decay: float = 0.0
  use_groundtruth_dimension: bool = True


@dataclasses.dataclass
class BASNetTask(cfg.TaskConfig):
  """The model config."""
  model: BASNetModel = BASNetModel()
  train_data: DataConfig = DataConfig(is_training=True)
  validation_data: DataConfig = DataConfig(is_training=False)
  losses: Losses = Losses()
  gradient_clip_norm: float = 0.0
  init_checkpoint: Optional[str] = None
  init_checkpoint_modules: Union[str, List[str]] = 'backbone'  # all, backbone, and/or decoder


@exp_factory.register_config_factory('basnet')
def basnet() -> cfg.ExperimentConfig:
  """BASNet general."""
  return cfg.ExperimentConfig(
      task=BASNetTask(),
      trainer=cfg.TrainerConfig(),
      restrictions=[
          'task.train_data.is_training != None',
          'task.validation_data.is_training != None'
      ])


# DUTS Dataset
DUTS_TRAIN_EXAMPLES = 10553
DUTS_VAL_EXAMPLES = 5019
DUTS_INPUT_PATH_BASE_TR = 'DUTS_DATASET'
DUTS_INPUT_PATH_BASE_VAL = 'DUTS_DATASET'


@exp_factory.register_config_factory('basnet_duts')
def basnet_duts() -> cfg.ExperimentConfig:
  """Image segmentation on duts with basnet."""
  train_batch_size = 64
  eval_batch_size = 16
  steps_per_epoch = DUTS_TRAIN_EXAMPLES // train_batch_size
  config = cfg.ExperimentConfig(
      task=BASNetTask(
          model=BASNetModel(
              input_size=[None, None, 3],
              use_bias=True,
              norm_activation=common.NormActivation(
                  activation='relu',
                  norm_momentum=0.99,
                  norm_epsilon=1e-3,
                  use_sync_bn=True)),
          losses=Losses(l2_weight_decay=0),
          train_data=DataConfig(
              input_path=os.path.join(DUTS_INPUT_PATH_BASE_TR,
                                      'tf_record_train'),
              file_type='tfrecord',
              crop_size=[224, 224],
              output_size=[256, 256],
              is_training=True,
              global_batch_size=train_batch_size,
          ),
          validation_data=DataConfig(
              input_path=os.path.join(DUTS_INPUT_PATH_BASE_VAL,
                                      'tf_record_test'),
              file_type='tfrecord',
              output_size=[256, 256],
              is_training=False,
              global_batch_size=eval_batch_size,
          ),
          init_checkpoint='gs://cloud-basnet-checkpoints/basnet_encoder_imagenet/ckpt-340306',
          init_checkpoint_modules='backbone'),
      trainer=cfg.TrainerConfig(
          steps_per_loop=steps_per_epoch,
          summary_interval=steps_per_epoch,
          checkpoint_interval=steps_per_epoch,
          train_steps=300 * steps_per_epoch,
          validation_steps=DUTS_VAL_EXAMPLES // eval_batch_size,
          validation_interval=steps_per_epoch,
          optimizer_config=optimization.OptimizationConfig({
              'optimizer': {
                  'type': 'adam',
                  'adam': {
                      'beta_1': 0.9,
                      'beta_2': 0.999,
                      'epsilon': 1e-8,
                  }
              },
              'learning_rate': {
                  'type': 'constant',
                  'constant': {'learning_rate': 0.001}
              }
          })),
      restrictions=[
          'task.train_data.is_training != None',
          'task.validation_data.is_training != None'
      ])
  return config
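For reference, with the constants above the trainer schedule works out to steps_per_epoch = 10553 // 64 = 164, train_steps = 300 * 164 = 49,200, and validation_steps = 5019 // 16 = 313, with summaries, checkpoints, and validation all keyed to one epoch.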
official/projects/basnet/configs/basnet_test.py
0 → 100644
View file @ ca552843
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Tests for basnet configs."""
# pylint: disable=unused-import
from absl.testing import parameterized
import tensorflow as tf

from official.core import exp_factory
from official.modeling.hyperparams import config_definitions as cfg
from official.projects.basnet.configs import basnet as exp_cfg


class BASNetConfigTest(tf.test.TestCase, parameterized.TestCase):

  @parameterized.parameters(('basnet_duts',))
  def test_basnet_configs(self, config_name):
    config = exp_factory.get_exp_config(config_name)
    self.assertIsInstance(config, cfg.ExperimentConfig)
    self.assertIsInstance(config.task, exp_cfg.BASNetTask)
    self.assertIsInstance(config.task.model, exp_cfg.BASNetModel)
    self.assertIsInstance(config.task.train_data, exp_cfg.DataConfig)
    config.task.train_data.is_training = None
    with self.assertRaises(KeyError):
      config.validate()


if __name__ == '__main__':
  tf.test.main()
official/projects/basnet/configs/experiments/basnet_dut_gpu.yaml
0 → 100644
View file @ ca552843
runtime:
  distribution_strategy: 'mirrored'
  mixed_precision_dtype: 'float32'
  num_gpus: 8
task:
  train_data:
    dtype: 'float32'
  validation_data:
    resize_eval_groundtruth: true
    dtype: 'float32'
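This file only overrides a handful of fields; everything else keeps the defaults registered by `basnet_duts`. A rough sketch of the layering that `--config_file` performs (assuming `Config.override` accepts a nested dict with `is_strict`, as other Model Garden configs do):

```python
import yaml

from official.core import exp_factory
from official.projects.basnet.configs import basnet  # importing registers 'basnet_duts'

# Load the YAML overrides and apply them on top of the registered defaults.
with open('official/projects/basnet/configs/experiments/basnet_dut_gpu.yaml') as f:
  overrides = yaml.safe_load(f)

config = exp_factory.get_exp_config('basnet_duts')
config.override(overrides, is_strict=True)
print(config.runtime.num_gpus)            # 8
print(config.task.validation_data.dtype)  # float32
```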
official/projects/basnet/evaluation/metrics.py
0 → 100644
View file @ ca552843
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Evaluation metrics for BASNet.

The MAE and maxFscore implementations are a modified version of
https://github.com/xuebinqin/Binary-Segmentation-Evaluation-Tool
"""

import numpy as np
import scipy.signal


class MAE:
  """Mean Absolute Error (MAE) metric for basnet."""

  def __init__(self):
    """Constructs MAE metric class."""
    self.reset_states()

  @property
  def name(self):
    return 'MAE'

  def reset_states(self):
    """Resets internal states for a fresh run."""
    self._predictions = []
    self._groundtruths = []

  def result(self):
    """Evaluates segmentation results, and reset_states."""
    metric_result = self.evaluate()
    # Cleans up the internal variables in order for a fresh eval next time.
    self.reset_states()
    return metric_result

  def evaluate(self):
    """Evaluates with masks from all images.

    Returns:
      average_mae: average MAE with float numpy.
    """
    mae_total = 0.0
    for (true, pred) in zip(self._groundtruths, self._predictions):
      # Computes MAE
      mae = self._compute_mae(true, pred)
      mae_total += mae
    average_mae = mae_total / len(self._groundtruths)
    return average_mae

  def _mask_normalize(self, mask):
    return mask / (np.amax(mask) + 1e-8)

  def _compute_mae(self, true, pred):
    h, w = true.shape[0], true.shape[1]
    mask1 = self._mask_normalize(true)
    mask2 = self._mask_normalize(pred)
    sum_error = np.sum(np.absolute((mask1.astype(float) - mask2.astype(float))))
    mae_error = sum_error / (float(h) * float(w) + 1e-8)
    return mae_error

  def _convert_to_numpy(self, groundtruths, predictions):
    """Converts tensors to numpy arrays."""
    numpy_groundtruths = groundtruths.numpy()
    numpy_predictions = predictions.numpy()
    return numpy_groundtruths, numpy_predictions

  def update_state(self, groundtruths, predictions):
    """Update segmentation results and groundtruth data.

    Args:
      groundtruths: Tuple of single Tensor [batch, width, height, 1],
        groundtruth masks, range [0, 1].
      predictions: Tuple of single Tensor [batch, width, height, 1],
        predicted masks, range [0, 1].
    """
    groundtruths, predictions = self._convert_to_numpy(groundtruths[0],
                                                       predictions[0])
    for (true, pred) in zip(groundtruths, predictions):
      self._groundtruths.append(true)
      self._predictions.append(pred)


class MaxFscore:
  """Maximum F-score metric for basnet."""

  def __init__(self):
    """Constructs BASNet evaluation class."""
    self.reset_states()

  @property
  def name(self):
    return 'MaxFScore'

  def reset_states(self):
    """Resets internal states for a fresh run."""
    self._predictions = []
    self._groundtruths = []

  def result(self):
    """Evaluates segmentation results, and reset_states."""
    metric_result = self.evaluate()
    # Cleans up the internal variables in order for a fresh eval next time.
    self.reset_states()
    return metric_result

  def evaluate(self):
    """Evaluates with masks from all images.

    Returns:
      f_max: maximum F-score value.
    """
    mybins = np.arange(0, 256)
    beta = 0.3
    precisions = np.zeros((len(self._groundtruths), len(mybins) - 1))
    recalls = np.zeros((len(self._groundtruths), len(mybins) - 1))

    for i, (true, pred) in enumerate(
        zip(self._groundtruths, self._predictions)):
      # Compute F-score
      true = self._mask_normalize(true) * 255.0
      pred = self._mask_normalize(pred) * 255.0
      pre, rec = self._compute_pre_rec(true, pred, mybins=np.arange(0, 256))
      precisions[i, :] = pre
      recalls[i, :] = rec

    precisions = np.sum(precisions, 0) / (len(self._groundtruths) + 1e-8)
    recalls = np.sum(recalls, 0) / (len(self._groundtruths) + 1e-8)
    f = (1 + beta) * precisions * recalls / (beta * precisions + recalls + 1e-8)
    f_max = np.max(f)
    f_max = f_max.astype(np.float32)
    return f_max

  def _mask_normalize(self, mask):
    return mask / (np.amax(mask) + 1e-8)

  def _compute_pre_rec(self, true, pred, mybins=np.arange(0, 256)):
    """Computes precision and recall over the threshold bins."""
    # pixel number of ground truth foreground regions
    gt_num = true[true > 128].size
    # mask predicted pixel values in the ground truth foreground region
    pp = pred[true > 128]
    # mask predicted pixel values in the ground truth background region
    nn = pred[true <= 128]

    pp_hist, _ = np.histogram(pp, bins=mybins)
    nn_hist, _ = np.histogram(nn, bins=mybins)
    pp_hist_flip = np.flipud(pp_hist)
    nn_hist_flip = np.flipud(nn_hist)
    pp_hist_flip_cum = np.cumsum(pp_hist_flip)
    nn_hist_flip_cum = np.cumsum(nn_hist_flip)

    precision = pp_hist_flip_cum / (
        pp_hist_flip_cum + nn_hist_flip_cum + 1e-8)  # TP/(TP+FP)
    recall = pp_hist_flip_cum / (gt_num + 1e-8)  # TP/(TP+FN)

    precision[np.isnan(precision)] = 0.0
    recall[np.isnan(recall)] = 0.0

    pre_len = len(precision)
    rec_len = len(recall)

    return np.reshape(precision, (pre_len)), np.reshape(recall, (rec_len))

  def _convert_to_numpy(self, groundtruths, predictions):
    """Converts tensors to numpy arrays."""
    numpy_groundtruths = groundtruths.numpy()
    numpy_predictions = predictions.numpy()
    return numpy_groundtruths, numpy_predictions

  def update_state(self, groundtruths, predictions):
    """Update segmentation results and groundtruth data.

    Args:
      groundtruths: Tuple of single Tensor [batch, width, height, 1],
        groundtruth masks, range [0, 1].
      predictions: Tuple of single Tensor [batch, width, height, 1],
        predicted masks, range [0, 1].
    """
    groundtruths, predictions = self._convert_to_numpy(groundtruths[0],
                                                       predictions[0])
    for (true, pred) in zip(groundtruths, predictions):
      self._groundtruths.append(true)
      self._predictions.append(pred)


class RelaxedFscore:
  """Relaxed F-score metric for basnet."""

  def __init__(self):
    """Constructs BASNet evaluation class."""
    self.reset_states()

  @property
  def name(self):
    return 'RelaxFScore'

  def reset_states(self):
    """Resets internal states for a fresh run."""
    self._predictions = []
    self._groundtruths = []

  def result(self):
    """Evaluates segmentation results, and reset_states."""
    metric_result = self.evaluate()
    # Cleans up the internal variables in order for a fresh eval next time.
    self.reset_states()
    return metric_result

  def evaluate(self):
    """Evaluates with masks from all images.

    Returns:
      relax_f: relaxed F-score value.
    """
    beta = 0.3
    rho = 3
    relax_fs = np.zeros(len(self._groundtruths))
    erode_kernel = np.ones((3, 3))

    for i, (true, pred) in enumerate(
        zip(self._groundtruths, self._predictions)):
      true = self._mask_normalize(true)
      pred = self._mask_normalize(pred)
      true = np.squeeze(true, axis=-1)
      pred = np.squeeze(pred, axis=-1)

      # binary saliency mask (S_bw), threshold 0.5
      pred[pred >= 0.5] = 1
      pred[pred < 0.5] = 0
      # compute eroded binary mask (S_erd) of S_bw
      pred_erd = self._compute_erosion(pred, erode_kernel)
      pred_xor = np.logical_xor(pred_erd, pred)
      # convert True/False to 1/0
      pred_xor = pred_xor * 1

      # same method for ground truth
      true[true >= 0.5] = 1
      true[true < 0.5] = 0
      true_erd = self._compute_erosion(true, erode_kernel)
      true_xor = np.logical_xor(true_erd, true)
      true_xor = true_xor * 1

      pre, rec = self._compute_relax_pre_rec(true_xor, pred_xor, rho)
      relax_fs[i] = (1 + beta) * pre * rec / (beta * pre + rec + 1e-8)

    relax_f = np.sum(relax_fs, 0) / (len(self._groundtruths) + 1e-8)
    relax_f = relax_f.astype(np.float32)
    return relax_f

  def _mask_normalize(self, mask):
    return mask / (np.amax(mask) + 1e-8)

  def _compute_erosion(self, mask, kernel):
    kernel_full = np.sum(kernel)
    mask_erd = scipy.signal.convolve2d(mask, kernel, mode='same')
    mask_erd[mask_erd < kernel_full] = 0
    mask_erd[mask_erd == kernel_full] = 1
    return mask_erd

  def _compute_relax_pre_rec(self, true, pred, rho):
    """Computes relaxed precision and recall."""
    kernel = np.ones((2 * rho - 1, 2 * rho - 1))
    map_zeros = np.zeros_like(pred)
    map_ones = np.ones_like(pred)

    pred_filtered = scipy.signal.convolve2d(pred, kernel, mode='same')
    # True positive for relaxed precision
    relax_pre_tp = np.where((true == 1) & (pred_filtered > 0), map_ones,
                            map_zeros)
    true_filtered = scipy.signal.convolve2d(true, kernel, mode='same')
    # True positive for relaxed recall
    relax_rec_tp = np.where((pred == 1) & (true_filtered > 0), map_ones,
                            map_zeros)

    return np.sum(relax_pre_tp) / np.sum(pred), np.sum(relax_rec_tp) / np.sum(true)

  def _convert_to_numpy(self, groundtruths, predictions):
    """Converts tensors to numpy arrays."""
    numpy_groundtruths = groundtruths.numpy()
    numpy_predictions = predictions.numpy()
    return numpy_groundtruths, numpy_predictions

  def update_state(self, groundtruths, predictions):
    """Update segmentation results and groundtruth data.

    Args:
      groundtruths: Tuple of single Tensor [batch, width, height, 1],
        groundtruth masks, range [0, 1].
      predictions: Tuple of single Tensor [batch, width, height, 1],
        predicted masks, range [0, 1].
    """
    groundtruths, predictions = self._convert_to_numpy(groundtruths[0],
                                                       predictions[0])
    for (true, pred) in zip(groundtruths, predictions):
      self._groundtruths.append(true)
      self._predictions.append(pred)
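A small usage sketch for these metrics with toy data (not from the repo): `update_state` expects a tuple holding a single `[batch, width, height, 1]` tensor in `[0, 1]` for both arguments, and `result()` reduces over every image accumulated so far before resetting the internal state.

```python
import tensorflow as tf

from official.projects.basnet.evaluation import metrics

# Toy ground-truth masks (binary) and predictions (soft), batch of 2.
gt = tf.cast(tf.random.uniform((2, 64, 64, 1)) > 0.5, tf.float32)
pred = tf.random.uniform((2, 64, 64, 1))

mae = metrics.MAE()
max_f = metrics.MaxFscore()
mae.update_state((gt,), (pred,))
max_f.update_state((gt,), (pred,))
print(mae.result(), max_f.result())
```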