ModelZoo / ResNet50_tensorflow / Commits / 8f5f819f
"git@developer.sourcefind.cn:OpenDAS/vision.git" did not exist on "355e9d2f29a7597c12d78753e3cc250e819d7198"
Commit 8f5f819f, authored Jul 16, 2020 by Kaushik Shivakumar

    Merge branch 'master' of https://github.com/tensorflow/models into latest

Parents: 7c062a56, 709a6617

Changes: 44 changed files in the commit; this page shows 20 changed files with 368 additions and 102 deletions (+368 -102).
Changed files shown on this page (20):

  community/README.md                                              +8    -0
  official/README.md                                               +3    -6
  official/benchmark/resnet_ctl_imagenet_benchmark.py              +28   -8
  official/modeling/hyperparams/config_definitions.py              +21   -9
  official/nlp/configs/electra.py                                  +14   -5
  official/nlp/configs/encoders.py                                 +10   -5
  official/nlp/data/create_pretraining_data.py                     +233  -58
  official/nlp/data/sentence_prediction_dataloader.py              +6    -1
  official/nlp/modeling/models/bert_classifier.py                  +3    -0
  official/nlp/modeling/models/bert_pretrainer.py                  +3    -0
  official/nlp/modeling/models/bert_span_labeler.py                +4    -1
  official/nlp/modeling/models/bert_token_classifier.py            +3    -0
  official/nlp/modeling/models/electra_pretrainer.py               +17   -6
  official/nlp/modeling/models/electra_pretrainer_test.py          +0    -3
  official/nlp/modeling/networks/albert_transformer_encoder.py     +2    -0
  official/nlp/modeling/networks/classification.py                 +3    -0
  official/nlp/modeling/networks/encoder_scaffold.py               +3    -0
  official/nlp/modeling/networks/span_labeling.py                  +2    -0
  official/nlp/modeling/networks/token_classification.py           +2    -0
  official/nlp/modeling/networks/transformer_encoder.py            +3    -0
community/README.md  (view file @ 8f5f819f)

@@ -20,6 +20,14 @@ This repository provides a curated list of the GitHub repositories with machine
 | [ResNet 50](https://github.com/IntelAI/models/tree/master/benchmarks/image_recognition/tensorflow/resnet50) | [Deep Residual Learning for Image Recognition](https://arxiv.org/pdf/1512.03385) | • Int8 Inference<br/>• FP32 Inference | [Intel](https://github.com/IntelAI) |
 | [ResNet 50v1.5](https://github.com/IntelAI/models/tree/master/benchmarks/image_recognition/tensorflow/resnet50v1_5) | [Deep Residual Learning for Image Recognition](https://arxiv.org/pdf/1512.03385) | • Int8 Inference<br/>• FP32 Inference<br/>• FP32 Training | [Intel](https://github.com/IntelAI) |
 
+### Object Detection
+
+| Model | Paper | Features | Maintainer |
+|-------|-------|----------|------------|
+| [R-FCN](https://github.com/IntelAI/models/tree/master/benchmarks/object_detection/tensorflow/rfcn) | [R-FCN: Object Detection<br/>via Region-based Fully Convolutional Networks](https://arxiv.org/pdf/1605.06409) | • Int8 Inference<br/>• FP32 Inference | [Intel](https://github.com/IntelAI) |
+| [SSD-MobileNet](https://github.com/IntelAI/models/tree/master/benchmarks/object_detection/tensorflow/ssd-mobilenet) | [MobileNets: Efficient Convolutional Neural Networks<br/>for Mobile Vision Applications](https://arxiv.org/pdf/1704.04861) | • Int8 Inference<br/>• FP32 Inference | [Intel](https://github.com/IntelAI) |
+| [SSD-ResNet34](https://github.com/IntelAI/models/tree/master/benchmarks/object_detection/tensorflow/ssd-resnet34) | [SSD: Single Shot MultiBox Detector](https://arxiv.org/pdf/1512.02325) | • Int8 Inference<br/>• FP32 Inference<br/>• FP32 Training | [Intel](https://github.com/IntelAI) |
+
 ### Segmentation
 | Model | Paper | Features | Maintainer |
official/README.md  (view file @ 8f5f819f)

@@ -17,12 +17,9 @@ with the same or improved speed and performance with each new TensorFlow build.
 The team is actively developing new models.
 In the near future, we will add:
-* State-of-the-art language understanding models:
-  More members in Transformer family
-* State-of-the-art image classification models:
-  EfficientNet, MnasNet, and variants
-* State-of-the-art objection detection and instance segmentation models:
-  RetinaNet, Mask R-CNN, SpineNet, and variants
+* State-of-the-art language understanding models.
+* State-of-the-art image classification models.
+* State-of-the-art objection detection and instance segmentation models.
 
 ## Table of Contents
official/benchmark/resnet_ctl_imagenet_benchmark.py  (view file @ 8f5f819f)

@@ -38,13 +38,18 @@ FLAGS = flags.FLAGS
 class CtlBenchmark(PerfZeroBenchmark):
   """Base benchmark class with methods to simplify testing."""
 
-  def __init__(self, output_dir=None, default_flags=None, flag_methods=None):
+  def __init__(self,
+               output_dir=None,
+               default_flags=None,
+               flag_methods=None,
+               **kwargs):
     self.default_flags = default_flags or {}
     self.flag_methods = flag_methods or {}
     super(CtlBenchmark, self).__init__(
         output_dir=output_dir,
         default_flags=self.default_flags,
-        flag_methods=self.flag_methods)
+        flag_methods=self.flag_methods,
+        **kwargs)
 
   def _report_benchmark(self,
                         stats,

@@ -190,13 +195,14 @@ class Resnet50CtlAccuracy(CtlBenchmark):
 class Resnet50CtlBenchmarkBase(CtlBenchmark):
   """Resnet50 benchmarks."""
 
-  def __init__(self, output_dir=None, default_flags=None):
+  def __init__(self, output_dir=None, default_flags=None, **kwargs):
     flag_methods = [common.define_keras_flags]
 
     super(Resnet50CtlBenchmarkBase, self).__init__(
         output_dir=output_dir,
         flag_methods=flag_methods,
-        default_flags=default_flags)
+        default_flags=default_flags,
+        **kwargs)
 
   @benchmark_wrappers.enable_runtime_flags
   def _run_and_report_benchmark(self):

@@ -381,12 +387,24 @@ class Resnet50CtlBenchmarkBase(CtlBenchmark):
     FLAGS.single_l2_loss_op = True
     FLAGS.use_tf_function = True
     FLAGS.enable_checkpoint_and_export = False
+    FLAGS.data_dir = 'gs://mlcompass-data/imagenet/imagenet-2012-tfrecord'
 
   def benchmark_2x2_tpu_bf16(self):
     self._setup()
     self._set_df_common()
     FLAGS.batch_size = 1024
     FLAGS.dtype = 'bf16'
+    FLAGS.model_dir = self._get_model_dir('benchmark_2x2_tpu_bf16')
+    self._run_and_report_benchmark()
+
+  @owner_utils.Owner('tf-graph-compiler')
+  def benchmark_2x2_tpu_bf16_mlir(self):
+    self._setup()
+    self._set_df_common()
+    FLAGS.batch_size = 1024
+    FLAGS.dtype = 'bf16'
+    tf.config.experimental.enable_mlir_bridge()
+    FLAGS.model_dir = self._get_model_dir('benchmark_2x2_tpu_bf16_mlir')
     self._run_and_report_benchmark()
 
   def benchmark_4x4_tpu_bf16(self):

@@ -394,6 +412,7 @@ class Resnet50CtlBenchmarkBase(CtlBenchmark):
     self._set_df_common()
     FLAGS.batch_size = 4096
     FLAGS.dtype = 'bf16'
+    FLAGS.model_dir = self._get_model_dir('benchmark_4x4_tpu_bf16')
     self._run_and_report_benchmark()
 
   @owner_utils.Owner('tf-graph-compiler')

@@ -403,6 +422,7 @@ class Resnet50CtlBenchmarkBase(CtlBenchmark):
     self._set_df_common()
     FLAGS.batch_size = 4096
     FLAGS.dtype = 'bf16'
+    FLAGS.model_dir = self._get_model_dir('benchmark_4x4_tpu_bf16_mlir')
     tf.config.experimental.enable_mlir_bridge()
     self._run_and_report_benchmark()

@@ -426,11 +446,11 @@ class Resnet50CtlBenchmarkSynth(Resnet50CtlBenchmarkBase):
     def_flags['skip_eval'] = True
     def_flags['use_synthetic_data'] = True
     def_flags['train_steps'] = 110
-    def_flags['steps_per_loop'] = 20
+    def_flags['steps_per_loop'] = 10
     def_flags['log_steps'] = 10
 
     super(Resnet50CtlBenchmarkSynth, self).__init__(
-        output_dir=output_dir, default_flags=def_flags)
+        output_dir=output_dir, default_flags=def_flags, **kwargs)
 
 
 class Resnet50CtlBenchmarkReal(Resnet50CtlBenchmarkBase):

@@ -441,11 +461,11 @@ class Resnet50CtlBenchmarkReal(Resnet50CtlBenchmarkBase):
     def_flags['skip_eval'] = True
     def_flags['data_dir'] = os.path.join(root_data_dir, 'imagenet')
     def_flags['train_steps'] = 110
-    def_flags['steps_per_loop'] = 20
+    def_flags['steps_per_loop'] = 10
     def_flags['log_steps'] = 10
 
     super(Resnet50CtlBenchmarkReal, self).__init__(
-        output_dir=output_dir, default_flags=def_flags)
+        output_dir=output_dir, default_flags=def_flags, **kwargs)
 
 
 if __name__ == '__main__':
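The constructor changes above thread extra keyword arguments through to PerfZeroBenchmark, so a benchmark harness can pass options that the subclasses do not name explicitly. A minimal, self-contained sketch of that pass-through pattern; the class and option names below are illustrative, not part of this diff:

# Sketch of the **kwargs pass-through used by the benchmark constructors above.
class BaseBenchmark:
  def __init__(self, output_dir=None, default_flags=None, **kwargs):
    self.output_dir = output_dir
    self.default_flags = default_flags or {}
    # Options the subclass did not consume are retained for the harness.
    self.extra_options = kwargs


class Resnet50Benchmark(BaseBenchmark):
  def __init__(self, output_dir=None, default_flags=None, **kwargs):
    # The subclass no longer enumerates every harness option; anything it
    # does not use is simply forwarded to the base class.
    super(Resnet50Benchmark, self).__init__(
        output_dir=output_dir, default_flags=default_flags, **kwargs)


bench = Resnet50Benchmark(output_dir='/tmp/bench', root_data_dir='/data')
print(bench.extra_options)  # {'root_data_dir': '/data'}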
official/modeling/hyperparams/config_definitions.py  (view file @ 8f5f819f)

@@ -14,6 +14,7 @@
 # limitations under the License.
 # ==============================================================================
 """Common configuration settings."""
+
 from typing import Optional, Union
 
 import dataclasses

@@ -111,6 +112,8 @@ class RuntimeConfig(base_config.Config):
     run_eagerly: Whether or not to run the experiment eagerly.
     batchnorm_spatial_persistent: Whether or not to enable the spatial
       persistent mode for CuDNN batch norm kernel for improved GPU performance.
+    allow_tpu_summary: Whether to allow summary happen inside the XLA program
+      runs on TPU through automatic outside compilation.
   """
   distribution_strategy: str = "mirrored"
   enable_xla: bool = False

@@ -123,8 +126,8 @@ class RuntimeConfig(base_config.Config):
   task_index: int = -1
   all_reduce_alg: Optional[str] = None
   num_packs: int = 1
-  loss_scale: Optional[Union[str, float]] = None
   mixed_precision_dtype: Optional[str] = None
+  loss_scale: Optional[Union[str, float]] = None
   run_eagerly: bool = False
   batchnorm_spatial_persistent: bool = False

@@ -172,23 +175,32 @@ class TrainerConfig(base_config.Config):
     eval_tf_function: whether or not to use tf_function for eval.
     steps_per_loop: number of steps per loop.
     summary_interval: number of steps between each summary.
-    checkpoint_intervals: number of steps between checkpoints.
+    checkpoint_interval: number of steps between checkpoints.
     max_to_keep: max checkpoints to keep.
     continuous_eval_timeout: maximum number of seconds to wait between
-      checkpoints, if set to None, continuous eval will wait indefinetely.
+      checkpoints, if set to None, continuous eval will wait indefinitely.
+    train_steps: number of train steps.
+    validation_steps: number of eval steps. If `None`, the entire eval dataset
+      is used.
+    validation_interval: number of training steps to run between evaluations.
   """
   optimizer_config: OptimizationConfig = OptimizationConfig()
-  train_steps: int = 0
-  validation_steps: Optional[int] = None
-  validation_interval: int = 1000
+  # Orbit settings.
   train_tf_while_loop: bool = True
   train_tf_function: bool = True
   eval_tf_function: bool = True
+  allow_tpu_summary: bool = False
+  # Trainer intervals.
   steps_per_loop: int = 1000
   summary_interval: int = 1000
   checkpoint_interval: int = 1000
+  # Checkpoint manager.
   max_to_keep: int = 5
   continuous_eval_timeout: Optional[int] = None
+  # Train/Eval routines.
+  train_steps: int = 0
+  validation_steps: Optional[int] = None
+  validation_interval: int = 1000
 
 
 @dataclasses.dataclass
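The reordered TrainerConfig fields are grouped by the component that consumes them (Orbit loop settings, reporting intervals, checkpoint manager, train/eval routines). A small stand-alone sketch of the same grouping, using plain dataclasses rather than the repository's base_config.Config, for readers who want the resulting defaults in one place:

import dataclasses
from typing import Optional


@dataclasses.dataclass
class TrainerConfigSketch:
  """Stand-in mirroring the grouped fields added in config_definitions.py."""
  # Orbit settings.
  train_tf_while_loop: bool = True
  train_tf_function: bool = True
  eval_tf_function: bool = True
  allow_tpu_summary: bool = False
  # Trainer intervals.
  steps_per_loop: int = 1000
  summary_interval: int = 1000
  checkpoint_interval: int = 1000
  # Checkpoint manager.
  max_to_keep: int = 5
  continuous_eval_timeout: Optional[int] = None
  # Train/Eval routines.
  train_steps: int = 0
  validation_steps: Optional[int] = None
  validation_interval: int = 1000


cfg = TrainerConfigSketch(train_steps=90000, validation_interval=5000)
assert cfg.steps_per_loop == 1000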
official/nlp/configs/electra.py  (view file @ 8f5f819f)

@@ -34,6 +34,8 @@ class ELECTRAPretrainerConfig(base_config.Config):
   sequence_length: int = 512
   num_classes: int = 2
   discriminator_loss_weight: float = 50.0
+  tie_embeddings: bool = True
+  disallow_correct: bool = False
   generator_encoder: encoders.TransformerEncoderConfig = (
       encoders.TransformerEncoderConfig())
   discriminator_encoder: encoders.TransformerEncoderConfig = (

@@ -60,23 +62,30 @@ def instantiate_pretrainer_from_cfg(
   """Instantiates ElectraPretrainer from the config."""
   generator_encoder_cfg = config.generator_encoder
   discriminator_encoder_cfg = config.discriminator_encoder
-  if generator_network is None:
-    generator_network = encoders.instantiate_encoder_from_cfg(
-        generator_encoder_cfg)
+  # Copy discriminator's embeddings to generator for easier model serialization.
   if discriminator_network is None:
     discriminator_network = encoders.instantiate_encoder_from_cfg(
         discriminator_encoder_cfg)
+  if generator_network is None:
+    if config.tie_embeddings:
+      embedding_layer = discriminator_network.get_embedding_layer()
+      generator_network = encoders.instantiate_encoder_from_cfg(
+          generator_encoder_cfg, embedding_layer=embedding_layer)
+    else:
+      generator_network = encoders.instantiate_encoder_from_cfg(
+          generator_encoder_cfg)
+
   return electra_pretrainer.ElectraPretrainer(
       generator_network=generator_network,
       discriminator_network=discriminator_network,
       vocab_size=config.generator_encoder.vocab_size,
       num_classes=config.num_classes,
       sequence_length=config.sequence_length,
-      last_hidden_dim=config.generator_encoder.hidden_size,
       num_token_predictions=config.num_masked_tokens,
       mlm_activation=tf_utils.get_activation(
           generator_encoder_cfg.hidden_activation),
      mlm_initializer=tf.keras.initializers.TruncatedNormal(
           stddev=generator_encoder_cfg.initializer_range),
       classification_heads=instantiate_classification_heads_from_cfgs(
-          config.cls_heads))
+          config.cls_heads),
+      disallow_correct=config.disallow_correct)
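The new tie_embeddings option builds the discriminator first and then hands its embedding layer to the generator, so both towers share one embedding table; the comment in the diff notes this also simplifies serialization. A hedged, self-contained sketch of that wiring with stand-in classes rather than the repository's encoders:

# Stand-in sketch of the `tie_embeddings` wiring above; `FakeEncoder` and
# `build_pretrainer_networks` are illustrative, not the repository's API.
class FakeEncoder:
  def __init__(self, embedding_layer=None):
    # Reuse a provided embedding table, otherwise create a private one.
    self.embedding_layer = embedding_layer if embedding_layer is not None else object()

  def get_embedding_layer(self):
    return self.embedding_layer


def build_pretrainer_networks(tie_embeddings=True):
  discriminator = FakeEncoder()
  if tie_embeddings:
    # Same object, so generator and discriminator share one embedding table.
    generator = FakeEncoder(embedding_layer=discriminator.get_embedding_layer())
  else:
    generator = FakeEncoder()
  return generator, discriminator


gen, disc = build_pretrainer_networks(tie_embeddings=True)
assert gen.get_embedding_layer() is disc.get_embedding_layer()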
official/nlp/configs/encoders.py  (view file @ 8f5f819f)

@@ -17,12 +17,13 @@
 Includes configurations and instantiation methods.
 """
+from typing import Optional
 import dataclasses
+import gin
 import tensorflow as tf
 
 from official.modeling import tf_utils
 from official.modeling.hyperparams import base_config
+from official.nlp.modeling import layers
 from official.nlp.modeling import networks

@@ -40,11 +41,13 @@ class TransformerEncoderConfig(base_config.Config):
   max_position_embeddings: int = 512
   type_vocab_size: int = 2
   initializer_range: float = 0.02
+  embedding_size: Optional[int] = None
 
 
+@gin.configurable
 def instantiate_encoder_from_cfg(
     config: TransformerEncoderConfig,
-    encoder_cls=networks.TransformerEncoder):
+    encoder_cls=networks.TransformerEncoder,
+    embedding_layer: Optional[layers.OnDeviceEmbedding] = None):
   """Instantiate a Transformer encoder network from TransformerEncoderConfig."""
   if encoder_cls.__name__ == "EncoderScaffold":
     embedding_cfg = dict(

@@ -91,5 +94,7 @@ def instantiate_encoder_from_cfg(config: TransformerEncoderConfig,
       max_sequence_length=config.max_position_embeddings,
       type_vocab_size=config.type_vocab_size,
       initializer=tf.keras.initializers.TruncatedNormal(
-          stddev=config.initializer_range))
+          stddev=config.initializer_range),
+      embedding_width=config.embedding_size,
+      embedding_layer=embedding_layer)
   return encoder_network
official/nlp/data/create_pretraining_data.py  (view file @ 8f5f819f)

@@ -18,6 +18,7 @@ from __future__ import division
 from __future__ import print_function
 
 import collections
+import itertools
 import random
 
 from absl import app

@@ -48,6 +49,12 @@ flags.DEFINE_bool(
     "do_whole_word_mask", False,
     "Whether to use whole word masking rather than per-WordPiece masking.")
 
+flags.DEFINE_integer(
+    "max_ngram_size", None,
+    "Mask contiguous whole words (n-grams) of up to `max_ngram_size` using a "
+    "weighting scheme to favor shorter n-grams. "
+    "Note: `--do_whole_word_mask=True` must also be set when n-gram masking.")
+
 flags.DEFINE_bool(
     "gzip_compress", False,
     "Whether to use `GZIP` compress option to get compressed TFRecord files.")

@@ -192,7 +199,8 @@ def create_training_instances(input_files,
                               masked_lm_prob,
                               max_predictions_per_seq,
                               rng,
-                              do_whole_word_mask=False):
+                              do_whole_word_mask=False,
+                              max_ngram_size=None):
   """Create `TrainingInstance`s from raw text."""
   all_documents = [[]]

@@ -229,7 +237,7 @@ def create_training_instances(input_files,
           create_instances_from_document(
               all_documents, document_index, max_seq_length, short_seq_prob,
               masked_lm_prob, max_predictions_per_seq, vocab_words, rng,
-              do_whole_word_mask))
+              do_whole_word_mask, max_ngram_size))
 
   rng.shuffle(instances)
   return instances

@@ -238,7 +246,8 @@ def create_training_instances(input_files,
 def create_instances_from_document(
     all_documents, document_index, max_seq_length, short_seq_prob,
     masked_lm_prob, max_predictions_per_seq, vocab_words, rng,
-    do_whole_word_mask=False):
+    do_whole_word_mask=False,
+    max_ngram_size=None):
   """Creates `TrainingInstance`s for a single document."""
   document = all_documents[document_index]

@@ -337,7 +346,7 @@ def create_instances_from_document(
       (tokens, masked_lm_positions,
        masked_lm_labels) = create_masked_lm_predictions(
            tokens, masked_lm_prob, max_predictions_per_seq, vocab_words, rng,
-           do_whole_word_mask)
+           do_whole_word_mask, max_ngram_size)
       instance = TrainingInstance(
           tokens=tokens,
           segment_ids=segment_ids,

@@ -355,72 +364,238 @@ def create_instances_from_document(
 MaskedLmInstance = collections.namedtuple("MaskedLmInstance",
                                           ["index", "label"])
 
+# A _Gram is a [half-open) interval of token indices which form a word.
+# E.g.,
+#   words:  ["The", "doghouse"]
+#   tokens: ["The", "dog", "##house"]
+#   grams:  [(0,1), (1,3)]
+_Gram = collections.namedtuple("_Gram", ["begin", "end"])
+
+
+def _window(iterable, size):
+  """Helper to create a sliding window iterator with a given size.
+
+  E.g.,
+    input = [1, 2, 3, 4]
+    _window(input, 1) => [1], [2], [3], [4]
+    _window(input, 2) => [1, 2], [2, 3], [3, 4]
+    _window(input, 3) => [1, 2, 3], [2, 3, 4]
+    _window(input, 4) => [1, 2, 3, 4]
+    _window(input, 5) => None
+
+  Arguments:
+    iterable: elements to iterate over.
+    size: size of the window.
+
+  Yields:
+    Elements of `iterable` batched into a sliding window of length `size`.
+  """
+  i = iter(iterable)
+  window = []
+  try:
+    for e in range(0, size):
+      window.append(next(i))
+    yield window
+  except StopIteration:
+    # handle the case where iterable's length is less than the window size.
+    return
+  for e in i:
+    window = window[1:] + [e]
+    yield window
+
+
+def _contiguous(sorted_grams):
+  """Test whether a sequence of grams is contiguous.
+
+  Arguments:
+    sorted_grams: _Grams which are sorted in increasing order.
+
+  Returns:
+    True if `sorted_grams` are touching each other.
+
+    E.g.,
+      _contiguous([(1, 4), (4, 5), (5, 10)]) == True
+      _contiguous([(1, 2), (4, 5)]) == False
+  """
+  for a, b in _window(sorted_grams, 2):
+    if a.end != b.begin:
+      return False
+  return True
+
+
+def _masking_ngrams(grams, max_ngram_size, max_masked_tokens, rng):
+  """Create a list of masking {1, ..., n}-grams from a list of one-grams.
+
+  This is an extention of 'whole word masking' to mask multiple, contiguous
+  words such as (e.g., "the red boat").
+
+  Each input gram represents the token indices of a single word,
+     words:  ["the", "red", "boat"]
+     tokens: ["the", "red", "boa", "##t"]
+     grams:  [(0,1), (1,2), (2,4)]
+
+  For a `max_ngram_size` of three, possible outputs masks include:
+    1-grams: (0,1), (1,2), (2,4)
+    2-grams: (0,2), (1,4)
+    3-grams; (0,4)
+
+  Output masks will not overlap and contain less than `max_masked_tokens` total
+  tokens.  E.g., for the example above with `max_masked_tokens` as three,
+  valid outputs are,
+       [(0,1), (1,2)]  # "the", "red" covering two tokens
+       [(1,2), (2,4)]  # "red", "boa", "##t" covering three tokens
+
+  The length of the selected n-gram follows a zipf weighting to
+  favor shorter n-gram sizes (weight(1)=1, weight(2)=1/2, weight(3)=1/3, ...).
+
+  Arguments:
+    grams: List of one-grams.
+    max_ngram_size: Maximum number of contiguous one-grams combined to create
+      an n-gram.
+    max_masked_tokens: Maximum total number of tokens to be masked.
+    rng: `random.Random` generator.
+
+  Returns:
+    A list of n-grams to be used as masks.
+  """
+  if not grams:
+    return None
+
+  grams = sorted(grams)
+  num_tokens = grams[-1].end
+
+  # Ensure our grams are valid (i.e., they don't overlap).
+  for a, b in _window(grams, 2):
+    if a.end > b.begin:
+      raise ValueError("overlapping grams: {}".format(grams))
+
+  # Build map from n-gram length to list of n-grams.
+  ngrams = {i: [] for i in range(1, max_ngram_size + 1)}
+  for gram_size in range(1, max_ngram_size + 1):
+    for g in _window(grams, gram_size):
+      if _contiguous(g):
+        # Add an n-gram which spans these one-grams.
+        ngrams[gram_size].append(_Gram(g[0].begin, g[-1].end))
+
+  # Shuffle each list of n-grams.
+  for v in ngrams.values():
+    rng.shuffle(v)
+
+  # Create the weighting for n-gram length selection.
+  # Stored cummulatively for `random.choices` below.
+  cummulative_weights = list(
+      itertools.accumulate([1. / n for n in range(1, max_ngram_size + 1)]))
+
+  output_ngrams = []
+  # Keep a bitmask of which tokens have been masked.
+  masked_tokens = [False] * num_tokens
+  # Loop until we have enough masked tokens or there are no more candidate
+  # n-grams of any length.
+  # Each code path should ensure one or more elements from `ngrams` are removed
+  # to guarentee this loop terminates.
+  while (sum(masked_tokens) < max_masked_tokens and
+         sum(len(s) for s in ngrams.values())):
+    # Pick an n-gram size based on our weights.
+    sz = random.choices(range(1, max_ngram_size + 1),
+                        cum_weights=cummulative_weights)[0]
+
+    # Ensure this size doesn't result in too many masked tokens.
+    # E.g., a two-gram contains _at least_ two tokens.
+    if sum(masked_tokens) + sz > max_masked_tokens:
+      # All n-grams of this length are too long and can be removed from
+      # consideration.
+      ngrams[sz].clear()
+      continue
+
+    # All of the n-grams of this size have been used.
+    if not ngrams[sz]:
+      continue
+
+    # Choose a random n-gram of the given size.
+    gram = ngrams[sz].pop()
+    num_gram_tokens = gram.end - gram.begin
+
+    # Check if this would add too many tokens.
+    if num_gram_tokens + sum(masked_tokens) > max_masked_tokens:
+      continue
+
+    # Check if any of the tokens in this gram have already been masked.
+    if sum(masked_tokens[gram.begin:gram.end]):
+      continue
+
+    # Found a usable n-gram!  Mark its tokens as masked and add it to return.
+    masked_tokens[gram.begin:gram.end] = [True] * (gram.end - gram.begin)
+    output_ngrams.append(gram)
+  return output_ngrams
+
+
+def _wordpieces_to_grams(tokens):
+  """Reconstitue grams (words) from `tokens`.
+
+  E.g.,
+     tokens: ['[CLS]', 'That', 'lit', '##tle', 'blue', 'tru', '##ck', '[SEP]']
+     grams:  [          [1,2), [2,         4),  [4,5) , [5,       6)]
+
+  Arguments:
+    tokens: list of wordpieces
+
+  Returns:
+    List of _Grams representing spans of whole words
+    (without "[CLS]" and "[SEP]").
+  """
+  grams = []
+  gram_start_pos = None
+  for i, token in enumerate(tokens):
+    if gram_start_pos is not None and token.startswith("##"):
+      continue
+    if gram_start_pos is not None:
+      grams.append(_Gram(gram_start_pos, i))
+    if token not in ["[CLS]", "[SEP]"]:
+      gram_start_pos = i
+    else:
+      gram_start_pos = None
+  if gram_start_pos is not None:
+    grams.append(_Gram(gram_start_pos, len(tokens)))
+  return grams
 
-def create_masked_lm_predictions(tokens, masked_lm_prob,
-                                 max_predictions_per_seq, vocab_words, rng,
-                                 do_whole_word_mask):
-  """Creates the predictions for the masked LM objective."""
-
-  cand_indexes = []
-  for (i, token) in enumerate(tokens):
-    if token == "[CLS]" or token == "[SEP]":
-      continue
-    # Whole Word Masking means that if we mask all of the wordpieces
-    # corresponding to an original word. When a word has been split into
-    # WordPieces, the first token does not have any marker and any subsequence
-    # tokens are prefixed with ##. So whenever we see the ## token, we
-    # append it to the previous set of word indexes.
-    #
-    # Note that Whole Word Masking does *not* change the training code
-    # at all -- we still predict each WordPiece independently, softmaxed
-    # over the entire vocabulary.
-    if (do_whole_word_mask and len(cand_indexes) >= 1 and
-        token.startswith("##")):
-      cand_indexes[-1].append(i)
-    else:
-      cand_indexes.append([i])
-
-  rng.shuffle(cand_indexes)
-
-  output_tokens = list(tokens)
-
-  num_to_predict = min(max_predictions_per_seq,
-                       max(1, int(round(len(tokens) * masked_lm_prob))))
-
-  masked_lms = []
-  covered_indexes = set()
-  for index_set in cand_indexes:
-    if len(masked_lms) >= num_to_predict:
-      break
-    # If adding a whole-word mask would exceed the maximum number of
-    # predictions, then just skip this candidate.
-    if len(masked_lms) + len(index_set) > num_to_predict:
-      continue
-    is_any_index_covered = False
-    for index in index_set:
-      if index in covered_indexes:
-        is_any_index_covered = True
-        break
-    if is_any_index_covered:
-      continue
-    for index in index_set:
-      covered_indexes.add(index)
-
-      masked_token = None
-      # 80% of the time, replace with [MASK]
-      if rng.random() < 0.8:
-        masked_token = "[MASK]"
-      else:
-        # 10% of the time, keep original
-        if rng.random() < 0.5:
-          masked_token = tokens[index]
-        # 10% of the time, replace with random word
-        else:
-          masked_token = vocab_words[rng.randint(0, len(vocab_words) - 1)]
-
-      output_tokens[index] = masked_token
-
-      masked_lms.append(MaskedLmInstance(index=index, label=tokens[index]))
-  assert len(masked_lms) <= num_to_predict
-
-  masked_lms = sorted(masked_lms, key=lambda x: x.index)
+
+def create_masked_lm_predictions(tokens, masked_lm_prob,
+                                 max_predictions_per_seq, vocab_words, rng,
+                                 do_whole_word_mask,
+                                 max_ngram_size=None):
+  """Creates the predictions for the masked LM objective."""
+  if do_whole_word_mask:
+    grams = _wordpieces_to_grams(tokens)
+  else:
+    # Here we consider each token to be a word to allow for sub-word masking.
+    if max_ngram_size:
+      raise ValueError("cannot use ngram masking without whole word masking")
+    grams = [_Gram(i, i + 1) for i in range(0, len(tokens))
+             if tokens[i] not in ["[CLS]", "[SEP]"]]
+
+  num_to_predict = min(max_predictions_per_seq,
+                       max(1, int(round(len(tokens) * masked_lm_prob))))
+  # Generate masks.  If `max_ngram_size` in [0, None] it means we're doing
+  # whole word masking or token level masking. Both of these can be treated
+  # as the `max_ngram_size=1` case.
+  masked_grams = _masking_ngrams(grams, max_ngram_size or 1,
+                                 num_to_predict, rng)
+  masked_lms = []
+  output_tokens = list(tokens)
+  for gram in masked_grams:
+    # 80% of the time, replace all n-gram tokens with [MASK]
+    if rng.random() < 0.8:
+      replacement_action = lambda idx: "[MASK]"
+    else:
+      # 10% of the time, keep all the original n-gram tokens.
+      if rng.random() < 0.5:
+        replacement_action = lambda idx: tokens[idx]
+      # 10% of the time, replace each n-gram token with a random word.
+      else:
+        replacement_action = lambda idx: rng.choice(vocab_words)
+
+    for idx in range(gram.begin, gram.end):
+      output_tokens[idx] = replacement_action(idx)
+      masked_lms.append(MaskedLmInstance(index=idx, label=tokens[idx]))
+
+  assert len(masked_lms) <= num_to_predict
+  masked_lms = sorted(masked_lms, key=lambda x: x.index)

@@ -467,7 +642,7 @@ def main(_):
   instances = create_training_instances(
       input_files, tokenizer, FLAGS.max_seq_length, FLAGS.dupe_factor,
       FLAGS.short_seq_prob, FLAGS.masked_lm_prob, FLAGS.max_predictions_per_seq,
-      rng, FLAGS.do_whole_word_mask)
+      rng, FLAGS.do_whole_word_mask, FLAGS.max_ngram_size)
 
   output_files = FLAGS.output_file.split(",")
   logging.info("*** Writing to output files ***")
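The helpers added above extend whole-word masking to contiguous n-grams: _wordpieces_to_grams groups wordpieces into whole-word spans, and _masking_ngrams samples non-overlapping spans with a Zipf-like preference for shorter n-grams. On the command line the feature is enabled with --do_whole_word_mask=True together with --max_ngram_size. A small usage sketch, assuming the models repository is importable; it calls the private helpers purely for illustration:

# Illustrative run of the new n-gram masking helpers on a toy token list.
import random

from official.nlp.data import create_pretraining_data as cpd

tokens = ['[CLS]', 'the', 'red', 'boa', '##t', '[SEP]']

# Whole words as half-open token-index intervals; "boa ##t" becomes (3, 5).
grams = cpd._wordpieces_to_grams(tokens)
print(grams)  # [_Gram(begin=1, end=2), _Gram(begin=2, end=3), _Gram(begin=3, end=5)]

# Sample non-overlapping 1- and 2-grams covering at most 3 tokens,
# with shorter n-grams favored by the Zipf weighting.
rng = random.Random(12345)
masks = cpd._masking_ngrams(grams, max_ngram_size=2, max_masked_tokens=3, rng=rng)
print(masks)  # e.g. [_Gram(begin=1, end=3)]; exact spans depend on random state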
official/nlp/data/sentence_prediction_dataloader.py  (view file @ 8f5f819f)

@@ -23,6 +23,9 @@ from official.modeling.hyperparams import config_definitions as cfg
 from official.nlp.data import data_loader_factory
 
+LABEL_TYPES_MAP = {'int': tf.int64, 'float': tf.float32}
+
 
 @dataclasses.dataclass
 class SentencePredictionDataConfig(cfg.DataConfig):
   """Data config for sentence prediction task (tasks/sentence_prediction)."""

@@ -30,6 +33,7 @@ class SentencePredictionDataConfig(cfg.DataConfig):
   global_batch_size: int = 32
   is_training: bool = True
   seq_length: int = 128
+  label_type: str = 'int'
 
 
 @data_loader_factory.register_data_loader_cls(SentencePredictionDataConfig)

@@ -42,11 +46,12 @@ class SentencePredictionDataLoader:
   def _decode(self, record: tf.Tensor):
     """Decodes a serialized tf.Example."""
+    label_type = LABEL_TYPES_MAP[self._params.label_type]
     name_to_features = {
         'input_ids': tf.io.FixedLenFeature([self._seq_length], tf.int64),
         'input_mask': tf.io.FixedLenFeature([self._seq_length], tf.int64),
         'segment_ids': tf.io.FixedLenFeature([self._seq_length], tf.int64),
-        'label_ids': tf.io.FixedLenFeature([], tf.int64),
+        'label_ids': tf.io.FixedLenFeature([], label_type),
     }
     example = tf.io.parse_single_example(record, name_to_features)
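The new label_type field lets the same dataloader parse integer labels for classification or float labels for regression-style scores; LABEL_TYPES_MAP simply maps the config string to a tf.DType. A brief sketch of the effect, using plain TensorFlow rather than the dataloader itself:

# Sketch of how `label_type` selects the parsed label dtype; names here
# mirror LABEL_TYPES_MAP above but the helper is illustrative only.
import tensorflow as tf

LABEL_TYPES_MAP = {'int': tf.int64, 'float': tf.float32}

def make_label_feature(label_type: str = 'int'):
  # 'float' is what a regression task with continuous scores would select.
  return tf.io.FixedLenFeature([], LABEL_TYPES_MAP[label_type])

print(make_label_feature('int'))    # FixedLenFeature(shape=[], dtype=tf.int64, ...)
print(make_label_feature('float'))  # FixedLenFeature(shape=[], dtype=tf.float32, ...)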
official/nlp/modeling/models/bert_classifier.py  (view file @ 8f5f819f)

@@ -37,6 +37,9 @@ class BertClassifier(tf.keras.Model):
   instantiates a classification network based on the passed `num_classes`
   argument. If `num_classes` is set to 1, a regression network is instantiated.
 
+  *Note* that the model is constructed by
+  [Keras Functional API](https://keras.io/guides/functional_api/).
+
   Arguments:
     network: A transformer network. This network should output a sequence output
       and a classification output. Furthermore, it should expose its embedding
official/nlp/modeling/models/bert_pretrainer.py  (view file @ 8f5f819f)

@@ -41,6 +41,9 @@ class BertPretrainer(tf.keras.Model):
   instantiates the masked language model and classification networks that are
   used to create the training objectives.
 
+  *Note* that the model is constructed by
+  [Keras Functional API](https://keras.io/guides/functional_api/).
+
   Arguments:
     network: A transformer network. This network should output a sequence output
       and a classification output.
official/nlp/modeling/models/bert_span_labeler.py  (view file @ 8f5f819f)

@@ -32,9 +32,12 @@ class BertSpanLabeler(tf.keras.Model):
   encoder as described in "BERT: Pre-training of Deep Bidirectional Transformers
   for Language Understanding" (https://arxiv.org/abs/1810.04805).
 
-  The BertSpanLabeler allows a user to pass in a transformer stack, and
+  The BertSpanLabeler allows a user to pass in a transformer encoder, and
   instantiates a span labeling network based on a single dense layer.
 
+  *Note* that the model is constructed by
+  [Keras Functional API](https://keras.io/guides/functional_api/).
+
   Arguments:
     network: A transformer network. This network should output a sequence output
       and a classification output. Furthermore, it should expose its embedding
official/nlp/modeling/models/bert_token_classifier.py  (view file @ 8f5f819f)

@@ -36,6 +36,9 @@ class BertTokenClassifier(tf.keras.Model):
   instantiates a token classification network based on the passed `num_classes`
   argument.
 
+  *Note* that the model is constructed by
+  [Keras Functional API](https://keras.io/guides/functional_api/).
+
   Arguments:
     network: A transformer network. This network should output a sequence output
       and a classification output. Furthermore, it should expose its embedding
official/nlp/modeling/models/electra_pretrainer.py  (view file @ 8f5f819f)

@@ -39,6 +39,9 @@ class ElectraPretrainer(tf.keras.Model):
   model (at generator side) and classification networks (at discriminator side)
   that are used to create the training objectives.
 
+  *Note* that the model is constructed by Keras Subclass API, where layers are
+  defined inside __init__ and call() implements the computation.
+
   Arguments:
     generator_network: A transformer network for generator, this network should
       output a sequence output and an optional classification output.

@@ -48,7 +51,6 @@ class ElectraPretrainer(tf.keras.Model):
     num_classes: Number of classes to predict from the classification network
       for the generator network (not used now)
     sequence_length: Input sequence length
-    last_hidden_dim: Last hidden dim of generator transformer output
     num_token_predictions: Number of tokens to predict from the masked LM.
     mlm_activation: The activation (if any) to use in the masked LM and
       classification networks. If None, no activation will be used.

@@ -66,7 +68,6 @@ class ElectraPretrainer(tf.keras.Model):
                vocab_size,
                num_classes,
                sequence_length,
-               last_hidden_dim,
               num_token_predictions,
               mlm_activation=None,
               mlm_initializer='glorot_uniform',

@@ -80,7 +81,6 @@ class ElectraPretrainer(tf.keras.Model):
         'vocab_size': vocab_size,
         'num_classes': num_classes,
         'sequence_length': sequence_length,
-        'last_hidden_dim': last_hidden_dim,
         'num_token_predictions': num_token_predictions,
         'mlm_activation': mlm_activation,
         'mlm_initializer': mlm_initializer,

@@ -95,7 +95,6 @@ class ElectraPretrainer(tf.keras.Model):
     self.vocab_size = vocab_size
     self.num_classes = num_classes
     self.sequence_length = sequence_length
-    self.last_hidden_dim = last_hidden_dim
     self.num_token_predictions = num_token_predictions
     self.mlm_activation = mlm_activation
     self.mlm_initializer = mlm_initializer

@@ -108,10 +107,15 @@ class ElectraPretrainer(tf.keras.Model):
         output=output_type,
         name='generator_masked_lm')
     self.classification = layers.ClassificationHead(
-        inner_dim=last_hidden_dim,
+        inner_dim=generator_network._config_dict['hidden_size'],
         num_classes=num_classes,
         initializer=mlm_initializer,
         name='generator_classification_head')
+    self.discriminator_projection = tf.keras.layers.Dense(
+        units=discriminator_network._config_dict['hidden_size'],
+        activation=mlm_activation,
+        kernel_initializer=mlm_initializer,
+        name='discriminator_projection_head')
     self.discriminator_head = tf.keras.layers.Dense(
         units=1, kernel_initializer=mlm_initializer)

@@ -165,7 +169,8 @@ class ElectraPretrainer(tf.keras.Model):
     if isinstance(disc_sequence_output, list):
       disc_sequence_output = disc_sequence_output[-1]
-    disc_logits = self.discriminator_head(disc_sequence_output)
+    disc_logits = self.discriminator_head(
+        self.discriminator_projection(disc_sequence_output))
     disc_logits = tf.squeeze(disc_logits, axis=-1)
 
     outputs = {

@@ -214,6 +219,12 @@ class ElectraPretrainer(tf.keras.Model):
         'sampled_tokens': sampled_tokens
     }
 
+  @property
+  def checkpoint_items(self):
+    """Returns a dictionary of items to be additionally checkpointed."""
+    items = dict(encoder=self.discriminator_network)
+    return items
+
   def get_config(self):
     return self._config
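The new checkpoint_items property exposes the discriminator encoder as the piece worth carrying into downstream fine-tuning checkpoints. A hedged sketch of how such a property can be consumed with tf.train.Checkpoint; the tiny model below is a stand-in, not the repository's pretrainer:

# Stand-in model with the same `checkpoint_items` contract as ElectraPretrainer.
import os

import tensorflow as tf


class TinyPretrainer(tf.keras.Model):

  def __init__(self):
    super(TinyPretrainer, self).__init__()
    self.discriminator_network = tf.keras.layers.Dense(4, name='encoder_standin')

  @property
  def checkpoint_items(self):
    # Mirrors the property added above: name the sub-objects worth reusing.
    return dict(encoder=self.discriminator_network)


model = TinyPretrainer()
model.discriminator_network.build((None, 4))

# Saving named items lets a downstream task restore only the encoder weights.
ckpt_dir = '/tmp/electra_sketch'
os.makedirs(ckpt_dir, exist_ok=True)
ckpt = tf.train.Checkpoint(**model.checkpoint_items)
print(ckpt.save(os.path.join(ckpt_dir, 'ckpt')))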
official/nlp/modeling/models/electra_pretrainer_test.py  (view file @ 8f5f819f)

@@ -49,7 +49,6 @@ class ElectraPretrainerTest(keras_parameterized.TestCase):
         vocab_size=vocab_size,
         num_classes=num_classes,
         sequence_length=sequence_length,
-        last_hidden_dim=768,
         num_token_predictions=num_token_predictions,
         disallow_correct=True)

@@ -101,7 +100,6 @@ class ElectraPretrainerTest(keras_parameterized.TestCase):
         vocab_size=100,
         num_classes=2,
         sequence_length=3,
-        last_hidden_dim=768,
         num_token_predictions=2)
 
     # Create a set of 2-dimensional data tensors to feed into the model.

@@ -140,7 +138,6 @@ class ElectraPretrainerTest(keras_parameterized.TestCase):
         vocab_size=100,
         num_classes=2,
         sequence_length=3,
-        last_hidden_dim=768,
         num_token_predictions=2)
 
     # Create another BERT trainer via serialization and deserialization.
official/nlp/modeling/networks/albert_transformer_encoder.py  (view file @ 8f5f819f)

@@ -40,6 +40,8 @@ class AlbertTransformerEncoder(tf.keras.Model):
   The default values for this object are taken from the ALBERT-Base
   implementation described in the paper.
 
+  *Note* that the network is constructed by Keras Functional API.
+
   Arguments:
     vocab_size: The size of the token vocabulary.
     embedding_width: The width of the word embeddings. If the embedding width is
official/nlp/modeling/networks/classification.py  (view file @ 8f5f819f)

@@ -29,6 +29,9 @@ class Classification(tf.keras.Model):
   This network implements a simple classifier head based on a dense layer. If
   num_classes is one, it can be considered as a regression problem.
 
+  *Note* that the network is constructed by
+  [Keras Functional API](https://keras.io/guides/functional_api/).
+
   Arguments:
     input_width: The innermost dimension of the input tensor to this network.
     num_classes: The number of classes that this network should classify to. If
official/nlp/modeling/networks/encoder_scaffold.py  (view file @ 8f5f819f)

@@ -49,6 +49,9 @@ class EncoderScaffold(tf.keras.Model):
   If the hidden_cls is not overridden, a default transformer layer will be
   instantiated.
 
+  *Note* that the network is constructed by
+  [Keras Functional API](https://keras.io/guides/functional_api/).
+
   Arguments:
     pooled_output_dim: The dimension of pooled output.
     pooler_layer_initializer: The initializer for the classification
official/nlp/modeling/networks/span_labeling.py  (view file @ 8f5f819f)

@@ -27,6 +27,8 @@ class SpanLabeling(tf.keras.Model):
   """Span labeling network head for BERT modeling.
 
   This network implements a simple single-span labeler based on a dense layer.
 
+  *Note* that the network is constructed by
+  [Keras Functional API](https://keras.io/guides/functional_api/).
+
   Arguments:
     input_width: The innermost dimension of the input tensor to this network.
official/nlp/modeling/networks/token_classification.py  (view file @ 8f5f819f)

@@ -27,6 +27,8 @@ class TokenClassification(tf.keras.Model):
   """TokenClassification network head for BERT modeling.
 
   This network implements a simple token classifier head based on a dense layer.
 
+  *Note* that the network is constructed by
+  [Keras Functional API](https://keras.io/guides/functional_api/).
+
   Arguments:
     input_width: The innermost dimension of the input tensor to this network.
official/nlp/modeling/networks/transformer_encoder.py  (view file @ 8f5f819f)

@@ -39,6 +39,9 @@ class TransformerEncoder(tf.keras.Model):
   in "BERT: Pre-training of Deep Bidirectional Transformers for Language
   Understanding".
 
+  *Note* that the network is constructed by
+  [Keras Functional API](https://keras.io/guides/functional_api/).
+
   Arguments:
     vocab_size: The size of the token vocabulary.
     hidden_size: The size of the transformer hidden layers.