ModelZoo / ResNet50_tensorflow / Commits

Commit e7667f6f, authored Jul 16, 2020 by Kaushik Shivakumar

Merge branch 'master' of https://github.com/tensorflow/models into context_tf2

Parents: 974c463e, 709a6617
Changes: 26 (showing 20 changed files with 370 additions and 27 deletions: +370, −27)
| File | Additions | Deletions |
|------|-----------|-----------|
| community/README.md | +8 | −0 |
| official/benchmark/resnet_ctl_imagenet_benchmark.py | +17 | −6 |
| official/modeling/hyperparams/config_definitions.py | +7 | −0 |
| official/nlp/configs/electra.py | +14 | −5 |
| official/nlp/configs/encoders.py | +10 | −5 |
| official/nlp/modeling/models/bert_classifier.py | +3 | −0 |
| official/nlp/modeling/models/bert_pretrainer.py | +3 | −0 |
| official/nlp/modeling/models/bert_span_labeler.py | +4 | −1 |
| official/nlp/modeling/models/bert_token_classifier.py | +3 | −0 |
| official/nlp/modeling/models/electra_pretrainer.py | +17 | −6 |
| official/nlp/modeling/models/electra_pretrainer_test.py | +0 | −3 |
| official/nlp/modeling/networks/albert_transformer_encoder.py | +2 | −0 |
| official/nlp/modeling/networks/classification.py | +3 | −0 |
| official/nlp/modeling/networks/encoder_scaffold.py | +3 | −0 |
| official/nlp/modeling/networks/span_labeling.py | +2 | −0 |
| official/nlp/modeling/networks/token_classification.py | +2 | −0 |
| official/nlp/modeling/networks/transformer_encoder.py | +3 | −0 |
| official/nlp/tasks/electra_task.py | +209 | −0 |
| official/nlp/tasks/electra_task_test.py | +59 | −0 |
| official/vision/image_classification/resnet/resnet_ctl_imagenet_main.py | +1 | −1 |
community/README.md

```diff
@@ -20,6 +20,14 @@ This repository provides a curated list of the GitHub repositories with machine
 | [ResNet 50](https://github.com/IntelAI/models/tree/master/benchmarks/image_recognition/tensorflow/resnet50) | [Deep Residual Learning for Image Recognition](https://arxiv.org/pdf/1512.03385) | • Int8 Inference<br/>• FP32 Inference | [Intel](https://github.com/IntelAI) |
 | [ResNet 50v1.5](https://github.com/IntelAI/models/tree/master/benchmarks/image_recognition/tensorflow/resnet50v1_5) | [Deep Residual Learning for Image Recognition](https://arxiv.org/pdf/1512.03385) | • Int8 Inference<br/>• FP32 Inference<br/>• FP32 Training | [Intel](https://github.com/IntelAI) |
+
+### Object Detection
+
+| Model | Paper | Features | Maintainer |
+|-------|-------|----------|------------|
+| [R-FCN](https://github.com/IntelAI/models/tree/master/benchmarks/object_detection/tensorflow/rfcn) | [R-FCN: Object Detection<br/>via Region-based Fully Convolutional Networks](https://arxiv.org/pdf/1605.06409) | • Int8 Inference<br/>• FP32 Inference | [Intel](https://github.com/IntelAI) |
+| [SSD-MobileNet](https://github.com/IntelAI/models/tree/master/benchmarks/object_detection/tensorflow/ssd-mobilenet) | [MobileNets: Efficient Convolutional Neural Networks<br/>for Mobile Vision Applications](https://arxiv.org/pdf/1704.04861) | • Int8 Inference<br/>• FP32 Inference | [Intel](https://github.com/IntelAI) |
+| [SSD-ResNet34](https://github.com/IntelAI/models/tree/master/benchmarks/object_detection/tensorflow/ssd-resnet34) | [SSD: Single Shot MultiBox Detector](https://arxiv.org/pdf/1512.02325) | • Int8 Inference<br/>• FP32 Inference<br/>• FP32 Training | [Intel](https://github.com/IntelAI) |
 ### Segmentation
 | Model | Paper | Features | Maintainer |
```
official/benchmark/resnet_ctl_imagenet_benchmark.py

```diff
@@ -38,13 +38,18 @@ FLAGS = flags.FLAGS
 class CtlBenchmark(PerfZeroBenchmark):
   """Base benchmark class with methods to simplify testing."""

-  def __init__(self, output_dir=None, default_flags=None, flag_methods=None):
+  def __init__(self,
+               output_dir=None,
+               default_flags=None,
+               flag_methods=None,
+               **kwargs):
     self.default_flags = default_flags or {}
     self.flag_methods = flag_methods or {}
     super(CtlBenchmark, self).__init__(
         output_dir=output_dir,
         default_flags=self.default_flags,
-        flag_methods=self.flag_methods)
+        flag_methods=self.flag_methods,
+        **kwargs)

   def _report_benchmark(self,
                         stats,
@@ -190,13 +195,14 @@ class Resnet50CtlAccuracy(CtlBenchmark):
 class Resnet50CtlBenchmarkBase(CtlBenchmark):
   """Resnet50 benchmarks."""

-  def __init__(self, output_dir=None, default_flags=None):
+  def __init__(self, output_dir=None, default_flags=None, **kwargs):
     flag_methods = [common.define_keras_flags]

     super(Resnet50CtlBenchmarkBase, self).__init__(
         output_dir=output_dir,
         flag_methods=flag_methods,
-        default_flags=default_flags)
+        default_flags=default_flags,
+        **kwargs)

   @benchmark_wrappers.enable_runtime_flags
   def _run_and_report_benchmark(self):
@@ -381,12 +387,14 @@ class Resnet50CtlBenchmarkBase(CtlBenchmark):
     FLAGS.single_l2_loss_op = True
     FLAGS.use_tf_function = True
     FLAGS.enable_checkpoint_and_export = False
+    FLAGS.data_dir = 'gs://mlcompass-data/imagenet/imagenet-2012-tfrecord'

   def benchmark_2x2_tpu_bf16(self):
     self._setup()
     self._set_df_common()
     FLAGS.batch_size = 1024
     FLAGS.dtype = 'bf16'
+    FLAGS.model_dir = self._get_model_dir('benchmark_2x2_tpu_bf16')
     self._run_and_report_benchmark()

   @owner_utils.Owner('tf-graph-compiler')
@@ -396,6 +404,7 @@ class Resnet50CtlBenchmarkBase(CtlBenchmark):
     FLAGS.batch_size = 1024
     FLAGS.dtype = 'bf16'
     tf.config.experimental.enable_mlir_bridge()
+    FLAGS.model_dir = self._get_model_dir('benchmark_2x2_tpu_bf16_mlir')
     self._run_and_report_benchmark()

   def benchmark_4x4_tpu_bf16(self):
@@ -403,6 +412,7 @@ class Resnet50CtlBenchmarkBase(CtlBenchmark):
     self._set_df_common()
     FLAGS.batch_size = 4096
     FLAGS.dtype = 'bf16'
+    FLAGS.model_dir = self._get_model_dir('benchmark_4x4_tpu_bf16')
     self._run_and_report_benchmark()

   @owner_utils.Owner('tf-graph-compiler')
@@ -412,6 +422,7 @@ class Resnet50CtlBenchmarkBase(CtlBenchmark):
     self._set_df_common()
     FLAGS.batch_size = 4096
     FLAGS.dtype = 'bf16'
+    FLAGS.model_dir = self._get_model_dir('benchmark_4x4_tpu_bf16_mlir')
     tf.config.experimental.enable_mlir_bridge()
     self._run_and_report_benchmark()
@@ -439,7 +450,7 @@ class Resnet50CtlBenchmarkSynth(Resnet50CtlBenchmarkBase):
     def_flags['log_steps'] = 10

     super(Resnet50CtlBenchmarkSynth, self).__init__(
-        output_dir=output_dir, default_flags=def_flags)
+        output_dir=output_dir, default_flags=def_flags, **kwargs)


 class Resnet50CtlBenchmarkReal(Resnet50CtlBenchmarkBase):
@@ -454,7 +465,7 @@ class Resnet50CtlBenchmarkReal(Resnet50CtlBenchmarkBase):
     def_flags['log_steps'] = 10

     super(Resnet50CtlBenchmarkReal, self).__init__(
-        output_dir=output_dir, default_flags=def_flags)
+        output_dir=output_dir, default_flags=def_flags, **kwargs)


 if __name__ == '__main__':
```
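The change repeated throughout this file is a `**kwargs` pass-through: every `__init__` in the benchmark hierarchy now forwards unrecognized keyword arguments up to `PerfZeroBenchmark`, so new base-class constructor options reach it without editing each subclass. A minimal sketch of the pattern with hypothetical `Base`/`Derived` classes (not the repo's actual ones):

```python
class Base:
  """Stands in for PerfZeroBenchmark: absorbs forwarded keyword options."""

  def __init__(self, output_dir=None, **kwargs):
    self.output_dir = output_dir
    self.extra_options = kwargs  # options added to Base later land here


class Derived(Base):
  """Stands in for CtlBenchmark: takes its own args, forwards the rest."""

  def __init__(self, output_dir=None, default_flags=None, **kwargs):
    self.default_flags = default_flags or {}
    super(Derived, self).__init__(output_dir=output_dir, **kwargs)


# A keyword only Base understands passes through Derived without a TypeError.
bench = Derived(output_dir='/tmp/out', tpu='my-tpu')
print(bench.extra_options)  # {'tpu': 'my-tpu'}
```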
official/modeling/hyperparams/config_definitions.py

```diff
@@ -112,6 +112,8 @@ class RuntimeConfig(base_config.Config):
     run_eagerly: Whether or not to run the experiment eagerly.
     batchnorm_spatial_persistent: Whether or not to enable the spatial
       persistent mode for CuDNN batch norm kernel for improved GPU performance.
+    allow_tpu_summary: Whether to allow summary happen inside the XLA program
+      runs on TPU through automatic outside compilation.
   """
   distribution_strategy: str = "mirrored"
   enable_xla: bool = False
@@ -183,14 +185,19 @@ class TrainerConfig(base_config.Config):
     validation_interval: number of training steps to run between evaluations.
   """
   optimizer_config: OptimizationConfig = OptimizationConfig()
+  # Orbit settings.
   train_tf_while_loop: bool = True
   train_tf_function: bool = True
   eval_tf_function: bool = True
+  allow_tpu_summary: bool = False
+  # Trainer intervals.
   steps_per_loop: int = 1000
   summary_interval: int = 1000
   checkpoint_interval: int = 1000
+  # Checkpoint manager.
   max_to_keep: int = 5
   continuous_eval_timeout: Optional[int] = None
+  # Train/Eval routines.
   train_steps: int = 0
   validation_steps: Optional[int] = None
   validation_interval: int = 1000
```
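A usage sketch for the new `allow_tpu_summary` field; assuming these dataclass-based configs accept field overrides as constructor keywords (the other values here are purely illustrative):

```python
from official.modeling.hyperparams import config_definitions as cfg

trainer = cfg.TrainerConfig(
    steps_per_loop=500,
    summary_interval=500,
    # Allow tf.summary calls inside the XLA-compiled TPU program; per the
    # docstring they are handled via automatic outside compilation.
    allow_tpu_summary=True)
```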
official/nlp/configs/electra.py

```diff
@@ -34,6 +34,8 @@ class ELECTRAPretrainerConfig(base_config.Config):
   sequence_length: int = 512
   num_classes: int = 2
   discriminator_loss_weight: float = 50.0
+  tie_embeddings: bool = True
+  disallow_correct: bool = False
   generator_encoder: encoders.TransformerEncoderConfig = (
       encoders.TransformerEncoderConfig())
   discriminator_encoder: encoders.TransformerEncoderConfig = (
@@ -60,23 +62,30 @@ def instantiate_pretrainer_from_cfg(
   """Instantiates ElectraPretrainer from the config."""
   generator_encoder_cfg = config.generator_encoder
   discriminator_encoder_cfg = config.discriminator_encoder
-  if generator_network is None:
-    generator_network = encoders.instantiate_encoder_from_cfg(
-        generator_encoder_cfg)
   if discriminator_network is None:
     discriminator_network = encoders.instantiate_encoder_from_cfg(
         discriminator_encoder_cfg)
+  if generator_network is None:
+    if config.tie_embeddings:
+      # Copy discriminator's embeddings to generator for easier model
+      # serialization.
+      embedding_layer = discriminator_network.get_embedding_layer()
+      generator_network = encoders.instantiate_encoder_from_cfg(
+          generator_encoder_cfg, embedding_layer=embedding_layer)
+    else:
+      generator_network = encoders.instantiate_encoder_from_cfg(
+          generator_encoder_cfg)
   return electra_pretrainer.ElectraPretrainer(
       generator_network=generator_network,
       discriminator_network=discriminator_network,
       vocab_size=config.generator_encoder.vocab_size,
       num_classes=config.num_classes,
       sequence_length=config.sequence_length,
-      last_hidden_dim=config.generator_encoder.hidden_size,
       num_token_predictions=config.num_masked_tokens,
       mlm_activation=tf_utils.get_activation(
           generator_encoder_cfg.hidden_activation),
       mlm_initializer=tf.keras.initializers.TruncatedNormal(
           stddev=generator_encoder_cfg.initializer_range),
-      classification_heads=instantiate_classification_heads_from_cfgs(
-          config.cls_heads))
+      classification_heads=instantiate_classification_heads_from_cfgs(
+          config.cls_heads),
+      disallow_correct=config.disallow_correct)
```
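`tie_embeddings` is wired through `instantiate_pretrainer_from_cfg` above, while `disallow_correct` is only forwarded to `ElectraPretrainer` here. As a rough sketch of what such a flag means in ELECTRA-style sampling (an assumption based on the flag name and the ELECTRA paper, not code in this diff), the generator's logit for the original token is masked out before sampling so every sampled token is a genuine replacement:

```python
import tensorflow as tf


def sample_replacements(logits, original_ids, disallow_correct=False):
  """Samples one replacement token per masked position (hypothetical helper)."""
  vocab_size = tf.shape(logits)[-1]
  if disallow_correct:
    # Subtract a large constant from the original token's logit so the
    # categorical sampler effectively never picks it.
    disallow = tf.one_hot(original_ids, depth=vocab_size, dtype=logits.dtype)
    logits = logits - 1e9 * disallow
  flat = tf.reshape(logits, [-1, vocab_size])
  samples = tf.random.categorical(flat, num_samples=1)
  return tf.reshape(samples, tf.shape(original_ids))


logits = tf.random.normal([2, 3, 10])               # [batch, predictions, vocab]
original_ids = tf.constant([[1, 2, 3], [4, 5, 6]])  # tokens to never re-sample
print(sample_replacements(logits, original_ids, disallow_correct=True))
```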
official/nlp/configs/encoders.py

```diff
@@ -17,12 +17,13 @@
 Includes configurations and instantiation methods.
 """
+from typing import Optional

 import dataclasses
+import gin
 import tensorflow as tf

 from official.modeling import tf_utils
 from official.modeling.hyperparams import base_config
+from official.nlp.modeling import layers
 from official.nlp.modeling import networks
@@ -40,11 +41,13 @@ class TransformerEncoderConfig(base_config.Config):
   max_position_embeddings: int = 512
   type_vocab_size: int = 2
   initializer_range: float = 0.02
+  embedding_size: Optional[int] = None


+@gin.configurable
 def instantiate_encoder_from_cfg(
     config: TransformerEncoderConfig,
-    encoder_cls=networks.TransformerEncoder):
+    encoder_cls=networks.TransformerEncoder,
+    embedding_layer: Optional[layers.OnDeviceEmbedding] = None):
   """Instantiate a Transformer encoder network from TransformerEncoderConfig."""
   if encoder_cls.__name__ == "EncoderScaffold":
     embedding_cfg = dict(
@@ -91,5 +94,7 @@ def instantiate_encoder_from_cfg(config: TransformerEncoderConfig,
       max_sequence_length=config.max_position_embeddings,
       type_vocab_size=config.type_vocab_size,
       initializer=tf.keras.initializers.TruncatedNormal(
-          stddev=config.initializer_range))
+          stddev=config.initializer_range),
+      embedding_width=config.embedding_size,
+      embedding_layer=embedding_layer)
   return encoder_network
```
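Taken together with the `electra.py` change above, the new `embedding_layer` argument lets a caller build one encoder and inject its embedding table into another. A usage sketch (the config values are illustrative; `get_embedding_layer()` is the accessor the ELECTRA config code uses in this same commit):

```python
from official.nlp.configs import encoders

config = encoders.TransformerEncoderConfig(vocab_size=30522, num_layers=1)
discriminator = encoders.instantiate_encoder_from_cfg(config)
# Share the discriminator's token embeddings with a second encoder, the same
# way instantiate_pretrainer_from_cfg does when tie_embeddings is True.
generator = encoders.instantiate_encoder_from_cfg(
    config, embedding_layer=discriminator.get_embedding_layer())
```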
official/nlp/modeling/models/bert_classifier.py

```diff
@@ -37,6 +37,9 @@ class BertClassifier(tf.keras.Model):
   instantiates a classification network based on the passed `num_classes`
   argument. If `num_classes` is set to 1, a regression network is instantiated.

+  *Note* that the model is constructed by
+  [Keras Functional API](https://keras.io/guides/functional_api/).
+
   Arguments:
     network: A transformer network. This network should output a sequence output
       and a classification output. Furthermore, it should expose its embedding
```
official/nlp/modeling/models/bert_pretrainer.py

```diff
@@ -41,6 +41,9 @@ class BertPretrainer(tf.keras.Model):
   instantiates the masked language model and classification networks that are
   used to create the training objectives.

+  *Note* that the model is constructed by
+  [Keras Functional API](https://keras.io/guides/functional_api/).
+
   Arguments:
     network: A transformer network. This network should output a sequence output
       and a classification output.
```
official/nlp/modeling/models/bert_span_labeler.py

```diff
@@ -32,9 +32,12 @@ class BertSpanLabeler(tf.keras.Model):
   encoder as described in "BERT: Pre-training of Deep Bidirectional Transformers
   for Language Understanding" (https://arxiv.org/abs/1810.04805).

-  The BertSpanLabeler allows a user to pass in a transformer stack, and
+  The BertSpanLabeler allows a user to pass in a transformer encoder, and
   instantiates a span labeling network based on a single dense layer.

+  *Note* that the model is constructed by
+  [Keras Functional API](https://keras.io/guides/functional_api/).
+
   Arguments:
     network: A transformer network. This network should output a sequence output
       and a classification output. Furthermore, it should expose its embedding
```
official/nlp/modeling/models/bert_token_classifier.py

```diff
@@ -36,6 +36,9 @@ class BertTokenClassifier(tf.keras.Model):
   instantiates a token classification network based on the passed `num_classes`
   argument.

+  *Note* that the model is constructed by
+  [Keras Functional API](https://keras.io/guides/functional_api/).
+
   Arguments:
     network: A transformer network. This network should output a sequence output
       and a classification output. Furthermore, it should expose its embedding
```
official/nlp/modeling/models/electra_pretrainer.py

```diff
@@ -39,6 +39,9 @@ class ElectraPretrainer(tf.keras.Model):
   model (at generator side) and classification networks (at discriminator side)
   that are used to create the training objectives.

+  *Note* that the model is constructed by Keras Subclass API, where layers are
+  defined inside __init__ and call() implements the computation.
+
   Arguments:
     generator_network: A transformer network for generator, this network should
       output a sequence output and an optional classification output.
@@ -48,7 +51,6 @@ class ElectraPretrainer(tf.keras.Model):
     num_classes: Number of classes to predict from the classification network
       for the generator network (not used now)
     sequence_length: Input sequence length
-    last_hidden_dim: Last hidden dim of generator transformer output
     num_token_predictions: Number of tokens to predict from the masked LM.
     mlm_activation: The activation (if any) to use in the masked LM and
       classification networks. If None, no activation will be used.
@@ -66,7 +68,6 @@ class ElectraPretrainer(tf.keras.Model):
                vocab_size,
                num_classes,
                sequence_length,
-               last_hidden_dim,
                num_token_predictions,
                mlm_activation=None,
                mlm_initializer='glorot_uniform',
@@ -80,7 +81,6 @@ class ElectraPretrainer(tf.keras.Model):
         'vocab_size': vocab_size,
         'num_classes': num_classes,
         'sequence_length': sequence_length,
-        'last_hidden_dim': last_hidden_dim,
         'num_token_predictions': num_token_predictions,
         'mlm_activation': mlm_activation,
         'mlm_initializer': mlm_initializer,
@@ -95,7 +95,6 @@ class ElectraPretrainer(tf.keras.Model):
     self.vocab_size = vocab_size
     self.num_classes = num_classes
     self.sequence_length = sequence_length
-    self.last_hidden_dim = last_hidden_dim
     self.num_token_predictions = num_token_predictions
     self.mlm_activation = mlm_activation
     self.mlm_initializer = mlm_initializer
@@ -108,10 +107,15 @@ class ElectraPretrainer(tf.keras.Model):
         output=output_type,
         name='generator_masked_lm')
     self.classification = layers.ClassificationHead(
-        inner_dim=last_hidden_dim,
+        inner_dim=generator_network._config_dict['hidden_size'],
         num_classes=num_classes,
         initializer=mlm_initializer,
         name='generator_classification_head')
+    self.discriminator_projection = tf.keras.layers.Dense(
+        units=discriminator_network._config_dict['hidden_size'],
+        activation=mlm_activation,
+        kernel_initializer=mlm_initializer,
+        name='discriminator_projection_head')
     self.discriminator_head = tf.keras.layers.Dense(
         units=1, kernel_initializer=mlm_initializer)
@@ -165,7 +169,8 @@ class ElectraPretrainer(tf.keras.Model):
     if isinstance(disc_sequence_output, list):
       disc_sequence_output = disc_sequence_output[-1]
-    disc_logits = self.discriminator_head(disc_sequence_output)
+    disc_logits = self.discriminator_head(
+        self.discriminator_projection(disc_sequence_output))
     disc_logits = tf.squeeze(disc_logits, axis=-1)

     outputs = {
@@ -214,6 +219,12 @@ class ElectraPretrainer(tf.keras.Model):
         'sampled_tokens': sampled_tokens
     }

+  @property
+  def checkpoint_items(self):
+    """Returns a dictionary of items to be additionally checkpointed."""
+    items = dict(encoder=self.discriminator_network)
+    return items
+
   def get_config(self):
     return self._config
```
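The discriminator side now applies a projection `Dense` (sized to the discriminator's hidden width and activated with `mlm_activation`) before the one-unit logit layer. A standalone sketch of the shape flow, with `'relu'` standing in for the configured activation:

```python
import tensorflow as tf

hidden_size = 128
projection = tf.keras.layers.Dense(hidden_size, activation='relu')
head = tf.keras.layers.Dense(1)

sequence_output = tf.random.normal([2, 16, hidden_size])  # [batch, seq, hidden]
# Project, score each token, then drop the trailing singleton dimension.
rtd_logits = tf.squeeze(head(projection(sequence_output)), axis=-1)
print(rtd_logits.shape)  # (2, 16): one replaced-token-detection logit per token
```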
official/nlp/modeling/models/electra_pretrainer_test.py

```diff
@@ -49,7 +49,6 @@ class ElectraPretrainerTest(keras_parameterized.TestCase):
         vocab_size=vocab_size,
         num_classes=num_classes,
         sequence_length=sequence_length,
-        last_hidden_dim=768,
         num_token_predictions=num_token_predictions,
         disallow_correct=True)
@@ -101,7 +100,6 @@ class ElectraPretrainerTest(keras_parameterized.TestCase):
         vocab_size=100,
         num_classes=2,
         sequence_length=3,
-        last_hidden_dim=768,
         num_token_predictions=2)

     # Create a set of 2-dimensional data tensors to feed into the model.
@@ -140,7 +138,6 @@ class ElectraPretrainerTest(keras_parameterized.TestCase):
         vocab_size=100,
         num_classes=2,
         sequence_length=3,
-        last_hidden_dim=768,
         num_token_predictions=2)

     # Create another BERT trainer via serialization and deserialization.
```
official/nlp/modeling/networks/albert_transformer_encoder.py

```diff
@@ -40,6 +40,8 @@ class AlbertTransformerEncoder(tf.keras.Model):
   The default values for this object are taken from the ALBERT-Base
   implementation described in the paper.

+  *Note* that the network is constructed by Keras Functional API.
+
   Arguments:
     vocab_size: The size of the token vocabulary.
     embedding_width: The width of the word embeddings. If the embedding width is
```
official/nlp/modeling/networks/classification.py

```diff
@@ -29,6 +29,9 @@ class Classification(tf.keras.Model):
   This network implements a simple classifier head based on a dense layer. If
   num_classes is one, it can be considered as a regression problem.

+  *Note* that the network is constructed by
+  [Keras Functional API](https://keras.io/guides/functional_api/).
+
   Arguments:
     input_width: The innermost dimension of the input tensor to this network.
     num_classes: The number of classes that this network should classify to. If
```
official/nlp/modeling/networks/encoder_scaffold.py

```diff
@@ -49,6 +49,9 @@ class EncoderScaffold(tf.keras.Model):
   If the hidden_cls is not overridden, a default transformer layer will be
   instantiated.

+  *Note* that the network is constructed by
+  [Keras Functional API](https://keras.io/guides/functional_api/).
+
   Arguments:
     pooled_output_dim: The dimension of pooled output.
     pooler_layer_initializer: The initializer for the classification
```
official/nlp/modeling/networks/span_labeling.py

```diff
@@ -27,6 +27,8 @@ class SpanLabeling(tf.keras.Model):
   """Span labeling network head for BERT modeling.

   This network implements a simple single-span labeler based on a dense layer.
+  *Note* that the network is constructed by
+  [Keras Functional API](https://keras.io/guides/functional_api/).

   Arguments:
     input_width: The innermost dimension of the input tensor to this network.
```
official/nlp/modeling/networks/token_classification.py

```diff
@@ -27,6 +27,8 @@ class TokenClassification(tf.keras.Model):
   """TokenClassification network head for BERT modeling.

   This network implements a simple token classifier head based on a dense layer.
+  *Note* that the network is constructed by
+  [Keras Functional API](https://keras.io/guides/functional_api/).

   Arguments:
     input_width: The innermost dimension of the input tensor to this network.
```
official/nlp/modeling/networks/transformer_encoder.py

```diff
@@ -39,6 +39,9 @@ class TransformerEncoder(tf.keras.Model):
   in "BERT: Pre-training of Deep Bidirectional Transformers for Language
   Understanding".

+  *Note* that the network is constructed by
+  [Keras Functional API](https://keras.io/guides/functional_api/).
+
   Arguments:
     vocab_size: The size of the token vocabulary.
     hidden_size: The size of the transformer hidden layers.
```
official/nlp/tasks/electra_task.py (new file, 0 → 100644)

```python
# Lint as: python3
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""ELECTRA pretraining task (Joint Masked LM and Replaced Token Detection)."""
import dataclasses
import tensorflow as tf

from official.core import base_task
from official.modeling.hyperparams import config_definitions as cfg
from official.nlp.configs import bert
from official.nlp.configs import electra
from official.nlp.data import pretrain_dataloader


@dataclasses.dataclass
class ELECTRAPretrainConfig(cfg.TaskConfig):
  """The model config."""
  model: electra.ELECTRAPretrainerConfig = electra.ELECTRAPretrainerConfig(
      cls_heads=[
          bert.ClsHeadConfig(
              inner_dim=768,
              num_classes=2,
              dropout_rate=0.1,
              name='next_sentence')
      ])
  train_data: cfg.DataConfig = cfg.DataConfig()
  validation_data: cfg.DataConfig = cfg.DataConfig()


@base_task.register_task_cls(ELECTRAPretrainConfig)
class ELECTRAPretrainTask(base_task.Task):
  """ELECTRA Pretrain Task (Masked LM + Replaced Token Detection)."""

  def build_model(self):
    return electra.instantiate_pretrainer_from_cfg(self.task_config.model)

  def build_losses(self,
                   labels,
                   model_outputs,
                   metrics,
                   aux_losses=None) -> tf.Tensor:
    metrics = dict([(metric.name, metric) for metric in metrics])
    # generator lm and (optional) nsp loss.
    lm_prediction_losses = tf.keras.losses.sparse_categorical_crossentropy(
        labels['masked_lm_ids'],
        tf.cast(model_outputs['lm_outputs'], tf.float32),
        from_logits=True)
    lm_label_weights = labels['masked_lm_weights']
    lm_numerator_loss = tf.reduce_sum(lm_prediction_losses * lm_label_weights)
    lm_denominator_loss = tf.reduce_sum(lm_label_weights)
    mlm_loss = tf.math.divide_no_nan(lm_numerator_loss, lm_denominator_loss)
    metrics['lm_example_loss'].update_state(mlm_loss)
    if 'next_sentence_labels' in labels:
      sentence_labels = labels['next_sentence_labels']
      sentence_outputs = tf.cast(
          model_outputs['sentence_outputs'], dtype=tf.float32)
      sentence_loss = tf.keras.losses.sparse_categorical_crossentropy(
          sentence_labels, sentence_outputs, from_logits=True)
      metrics['next_sentence_loss'].update_state(sentence_loss)
      total_loss = mlm_loss + sentence_loss
    else:
      total_loss = mlm_loss

    # discriminator replaced token detection (rtd) loss.
    rtd_logits = model_outputs['disc_logits']
    rtd_labels = tf.cast(model_outputs['disc_label'], tf.float32)
    input_mask = tf.cast(labels['input_mask'], tf.float32)
    rtd_ind_loss = tf.nn.sigmoid_cross_entropy_with_logits(
        logits=rtd_logits, labels=rtd_labels)
    rtd_numerator = tf.reduce_sum(input_mask * rtd_ind_loss)
    rtd_denominator = tf.reduce_sum(input_mask)
    rtd_loss = tf.math.divide_no_nan(rtd_numerator, rtd_denominator)
    metrics['discriminator_loss'].update_state(rtd_loss)
    total_loss = total_loss + \
        self.task_config.model.discriminator_loss_weight * rtd_loss

    if aux_losses:
      total_loss += tf.add_n(aux_losses)
    metrics['total_loss'].update_state(total_loss)
    return total_loss

  def build_inputs(self, params, input_context=None):
    """Returns tf.data.Dataset for pretraining."""
    if params.input_path == 'dummy':

      def dummy_data(_):
        dummy_ids = tf.zeros((1, params.seq_length), dtype=tf.int32)
        dummy_lm = tf.zeros((1, params.max_predictions_per_seq), dtype=tf.int32)
        return dict(
            input_word_ids=dummy_ids,
            input_mask=dummy_ids,
            input_type_ids=dummy_ids,
            masked_lm_positions=dummy_lm,
            masked_lm_ids=dummy_lm,
            masked_lm_weights=tf.cast(dummy_lm, dtype=tf.float32),
            next_sentence_labels=tf.zeros((1, 1), dtype=tf.int32))

      dataset = tf.data.Dataset.range(1)
      dataset = dataset.repeat()
      dataset = dataset.map(
          dummy_data, num_parallel_calls=tf.data.experimental.AUTOTUNE)
      return dataset

    return pretrain_dataloader.BertPretrainDataLoader(params).load(
        input_context)

  def build_metrics(self, training=None):
    del training
    metrics = [
        tf.keras.metrics.SparseCategoricalAccuracy(name='masked_lm_accuracy'),
        tf.keras.metrics.Mean(name='lm_example_loss'),
        tf.keras.metrics.SparseCategoricalAccuracy(
            name='discriminator_accuracy'),
    ]
    if self.task_config.train_data.use_next_sentence_label:
      metrics.append(
          tf.keras.metrics.SparseCategoricalAccuracy(
              name='next_sentence_accuracy'))
      metrics.append(tf.keras.metrics.Mean(name='next_sentence_loss'))
    metrics.append(tf.keras.metrics.Mean(name='discriminator_loss'))
    metrics.append(tf.keras.metrics.Mean(name='total_loss'))
    return metrics

  def process_metrics(self, metrics, labels, model_outputs):
    metrics = dict([(metric.name, metric) for metric in metrics])
    if 'masked_lm_accuracy' in metrics:
      metrics['masked_lm_accuracy'].update_state(labels['masked_lm_ids'],
                                                 model_outputs['lm_outputs'],
                                                 labels['masked_lm_weights'])
    if 'next_sentence_accuracy' in metrics:
      metrics['next_sentence_accuracy'].update_state(
          labels['next_sentence_labels'], model_outputs['sentence_outputs'])
    if 'discriminator_accuracy' in metrics:
      disc_logits_expanded = tf.expand_dims(model_outputs['disc_logits'], -1)
      discrim_full_logits = tf.concat(
          [-1.0 * disc_logits_expanded, disc_logits_expanded], -1)
      metrics['discriminator_accuracy'].update_state(
          model_outputs['disc_label'], discrim_full_logits,
          labels['input_mask'])

  def train_step(self, inputs, model: tf.keras.Model,
                 optimizer: tf.keras.optimizers.Optimizer, metrics):
    """Does forward and backward.

    Args:
      inputs: a dictionary of input tensors.
      model: the model, forward pass definition.
      optimizer: the optimizer for this training step.
      metrics: a nested structure of metrics objects.

    Returns:
      A dictionary of logs.
    """
    with tf.GradientTape() as tape:
      outputs = model(inputs, training=True)
      # Computes per-replica loss.
      loss = self.build_losses(
          labels=inputs,
          model_outputs=outputs,
          metrics=metrics,
          aux_losses=model.losses)
      # Scales loss as the default gradients allreduce performs sum inside the
      # optimizer.
      # TODO(b/154564893): enable loss scaling.
      scaled_loss = loss / tf.distribute.get_strategy().num_replicas_in_sync
    tvars = model.trainable_variables
    grads = tape.gradient(scaled_loss, tvars)
    optimizer.apply_gradients(list(zip(grads, tvars)))
    self.process_metrics(metrics, inputs, outputs)
    return {self.loss: loss}

  def validation_step(self, inputs, model: tf.keras.Model, metrics):
    """Validation step.

    Args:
      inputs: a dictionary of input tensors.
      model: the keras.Model.
      metrics: a nested structure of metrics objects.

    Returns:
      A dictionary of logs.
    """
    outputs = model(inputs, training=False)
    loss = self.build_losses(
        labels=inputs,
        model_outputs=outputs,
        metrics=metrics,
        aux_losses=model.losses)
    self.process_metrics(metrics, inputs, outputs)
    return {self.loss: loss}
```
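One subtle step above is `process_metrics` turning the single-logit RTD output into two-class logits via `tf.concat([-x, x], -1)`: softmax over `[-x, x]` equals `sigmoid(2x)` on the positive class, which ranks the two classes exactly as `sigmoid(x)` does, so `SparseCategoricalAccuracy` scores the binary prediction correctly. A small self-contained check:

```python
import tensorflow as tf

disc_logits = tf.constant([[-1.2, 0.3, 2.0]])  # [batch, seq]: RTD logits
disc_label = tf.constant([[0, 1, 1]])          # 1 = token was replaced

expanded = tf.expand_dims(disc_logits, -1)
full_logits = tf.concat([-1.0 * expanded, expanded], -1)  # [batch, seq, 2]

acc = tf.keras.metrics.SparseCategoricalAccuracy()
acc.update_state(disc_label, full_logits)
print(float(acc.result()))  # 1.0: argmax agrees with sigmoid(x) > 0.5 everywhere
```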
official/nlp/tasks/electra_task_test.py (new file, 0 → 100644)

```python
# Lint as: python3
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for official.nlp.tasks.electra_task."""
import tensorflow as tf

from official.nlp.configs import bert
from official.nlp.configs import electra
from official.nlp.configs import encoders
from official.nlp.data import pretrain_dataloader
from official.nlp.tasks import electra_task


class ELECTRAPretrainTaskTest(tf.test.TestCase):

  def test_task(self):
    config = electra_task.ELECTRAPretrainConfig(
        model=electra.ELECTRAPretrainerConfig(
            generator_encoder=encoders.TransformerEncoderConfig(
                vocab_size=30522, num_layers=1),
            discriminator_encoder=encoders.TransformerEncoderConfig(
                vocab_size=30522, num_layers=1),
            num_masked_tokens=20,
            sequence_length=128,
            cls_heads=[
                bert.ClsHeadConfig(
                    inner_dim=10, num_classes=2, name="next_sentence")
            ]),
        train_data=pretrain_dataloader.BertPretrainDataConfig(
            input_path="dummy",
            max_predictions_per_seq=20,
            seq_length=128,
            global_batch_size=1))
    task = electra_task.ELECTRAPretrainTask(config)
    model = task.build_model()
    metrics = task.build_metrics()
    dataset = task.build_inputs(config.train_data)

    iterator = iter(dataset)
    optimizer = tf.keras.optimizers.SGD(lr=0.1)
    task.train_step(next(iterator), model, optimizer, metrics=metrics)
    task.validation_step(next(iterator), model, metrics=metrics)


if __name__ == "__main__":
  tf.test.main()
```
official/vision/image_classification/resnet/resnet_ctl_imagenet_main.py

```diff
@@ -125,7 +125,7 @@ def run(flags_obj):
   per_epoch_steps, train_epochs, eval_steps = get_num_train_iterations(
       flags_obj)
-  if flags_obj.steps_per_loop is None:
+  if not flags_obj.steps_per_loop:
     steps_per_loop = per_epoch_steps
   elif flags_obj.steps_per_loop > per_epoch_steps:
     steps_per_loop = per_epoch_steps
```
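The one-line change swaps `is None` for a falsiness test, so `--steps_per_loop=0` now also falls back to `per_epoch_steps` instead of reaching the `>` comparison. Distilled into a standalone function (hypothetical name, same logic):

```python
def resolve_steps_per_loop(steps_per_loop, per_epoch_steps):
  if not steps_per_loop:  # catches both None (unset) and 0
    return per_epoch_steps
  # Never run a loop longer than one epoch.
  return min(steps_per_loop, per_epoch_steps)


assert resolve_steps_per_loop(None, 100) == 100
assert resolve_steps_per_loop(0, 100) == 100
assert resolve_steps_per_loop(500, 100) == 100
assert resolve_steps_per_loop(50, 100) == 50
```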