Commit 48e49875, authored Aug 05, 2020 by Hongkun Yu; committed by A. Unique TensorFlower, Aug 05, 2020.

Internal change

PiperOrigin-RevId: 325073831
Parent: 6994693d
Showing 16 changed files with 219 additions and 317 deletions (+219, -317).
Changed files:

  official/nlp/configs/bert.py                     +5   -32
  official/nlp/configs/bert_test.py                +0   -66
  official/nlp/configs/electra.py                  +4   -58
  official/nlp/configs/electra_test.py             +0   -49
  official/nlp/configs/encoders.py                 +73  -39
  official/nlp/modeling/models/__init__.py         +1   -1
  official/nlp/tasks/electra_task.py               +42  -9
  official/nlp/tasks/electra_task_test.py          +10  -8
  official/nlp/tasks/masked_lm.py                  +20  -7
  official/nlp/tasks/masked_lm_test.py             +4   -2
  official/nlp/tasks/question_answering.py         +4   -5
  official/nlp/tasks/question_answering_test.py    +38  -20
  official/nlp/tasks/sentence_prediction.py        +4   -6
  official/nlp/tasks/sentence_prediction_test.py   +9   -8
  official/nlp/tasks/tagging.py                    +2   -4
  official/nlp/tasks/tagging_test.py               +3   -3
official/nlp/configs/bert.py

@@ -20,13 +20,9 @@ Includes configurations and instantiation methods.
 from typing import List, Optional, Text
 import dataclasses
-import tensorflow as tf
-from official.modeling import tf_utils
 from official.modeling.hyperparams import base_config
 from official.nlp.configs import encoders
-from official.nlp.modeling import layers
-from official.nlp.modeling.models import bert_pretrainer


 @dataclasses.dataclass
@@ -40,32 +36,9 @@ class ClsHeadConfig(base_config.Config):


 @dataclasses.dataclass
-class BertPretrainerConfig(base_config.Config):
-  """BERT encoder configuration."""
-  encoder: encoders.TransformerEncoderConfig = (
-      encoders.TransformerEncoderConfig())
+class PretrainerConfig(base_config.Config):
+  """Pretrainer configuration."""
+  encoder: encoders.EncoderConfig = encoders.EncoderConfig()
   cls_heads: List[ClsHeadConfig] = dataclasses.field(default_factory=list)
+  mlm_activation: str = "gelu"
+  mlm_initializer_range: float = 0.02
-
-
-def instantiate_classification_heads_from_cfgs(
-    cls_head_configs: List[ClsHeadConfig]
-) -> List[layers.ClassificationHead]:
-  return [
-      layers.ClassificationHead(**cfg.as_dict()) for cfg in cls_head_configs
-  ] if cls_head_configs else []
-
-
-def instantiate_pretrainer_from_cfg(
-    config: BertPretrainerConfig,
-    encoder_network: Optional[tf.keras.Model] = None
-) -> bert_pretrainer.BertPretrainerV2:
-  """Instantiates a BertPretrainer from the config."""
-  encoder_cfg = config.encoder
-  if encoder_network is None:
-    encoder_network = encoders.instantiate_encoder_from_cfg(encoder_cfg)
-  return bert_pretrainer.BertPretrainerV2(
-      mlm_activation=tf_utils.get_activation(encoder_cfg.hidden_activation),
-      mlm_initializer=tf.keras.initializers.TruncatedNormal(
-          stddev=encoder_cfg.initializer_range),
-      encoder_network=encoder_network,
-      classification_heads=instantiate_classification_heads_from_cfgs(
-          config.cls_heads))
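Note the shape change: the pretrainer config no longer carries its own factory functions, and the encoder field becomes the one-of EncoderConfig. A minimal sketch of building the renamed config under the new schema (the override values are illustrative, not from the commit):

# Sketch only; vocab_size/num_layers values are illustrative.
from official.nlp.configs import bert
from official.nlp.configs import encoders

config = bert.PretrainerConfig(
    encoder=encoders.EncoderConfig(
        bert=encoders.BertEncoderConfig(vocab_size=30522, num_layers=1)),
    cls_heads=[
        bert.ClsHeadConfig(inner_dim=768, num_classes=2, name="next_sentence")
    ],
    mlm_activation="gelu",          # new field on PretrainerConfig
    mlm_initializer_range=0.02)     # new field on PretrainerConfig
# Model construction moved out of this module; see masked_lm.MaskedLMTask.build_model below.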
official/nlp/configs/bert_test.py  deleted (100644 → 0)

The entire file was removed:

# Lint as: python3
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for BERT configurations and models instantiation."""
import tensorflow as tf

from official.nlp.configs import bert
from official.nlp.configs import encoders


class BertModelsTest(tf.test.TestCase):

  def test_network_invocation(self):
    config = bert.BertPretrainerConfig(
        encoder=encoders.TransformerEncoderConfig(vocab_size=10, num_layers=1))
    _ = bert.instantiate_pretrainer_from_cfg(config)

    # Invokes with classification heads.
    config = bert.BertPretrainerConfig(
        encoder=encoders.TransformerEncoderConfig(vocab_size=10, num_layers=1),
        cls_heads=[
            bert.ClsHeadConfig(inner_dim=10, num_classes=2, name="next_sentence")
        ])
    _ = bert.instantiate_pretrainer_from_cfg(config)

    with self.assertRaises(ValueError):
      config = bert.BertPretrainerConfig(
          encoder=encoders.TransformerEncoderConfig(vocab_size=10, num_layers=1),
          cls_heads=[
              bert.ClsHeadConfig(inner_dim=10, num_classes=2, name="next_sentence"),
              bert.ClsHeadConfig(inner_dim=10, num_classes=2, name="next_sentence")
          ])
      _ = bert.instantiate_pretrainer_from_cfg(config)

  def test_checkpoint_items(self):
    config = bert.BertPretrainerConfig(
        encoder=encoders.TransformerEncoderConfig(vocab_size=10, num_layers=1),
        cls_heads=[
            bert.ClsHeadConfig(inner_dim=10, num_classes=2, name="next_sentence")
        ])
    encoder = bert.instantiate_pretrainer_from_cfg(config)
    self.assertSameElements(
        encoder.checkpoint_items.keys(),
        ["encoder", "masked_lm", "next_sentence.pooler_dense"])


if __name__ == "__main__":
  tf.test.main()
official/nlp/configs/electra.py

@@ -14,21 +14,17 @@
 # limitations under the License.
 # ==============================================================================
 """ELECTRA model configurations and instantiation methods."""
-from typing import List, Optional
+from typing import List
 import dataclasses
-import tensorflow as tf
-from official.modeling import tf_utils
 from official.modeling.hyperparams import base_config
 from official.nlp.configs import bert
 from official.nlp.configs import encoders
-from official.nlp.modeling import layers
-from official.nlp.modeling.models import electra_pretrainer


 @dataclasses.dataclass
-class ELECTRAPretrainerConfig(base_config.Config):
+class ElectraPretrainerConfig(base_config.Config):
   """ELECTRA pretrainer configuration."""
   num_masked_tokens: int = 76
   sequence_length: int = 512
@@ -36,56 +32,6 @@ class ELECTRAPretrainerConfig(base_config.Config):
   discriminator_loss_weight: float = 50.0
   tie_embeddings: bool = True
   disallow_correct: bool = False
-  generator_encoder: encoders.TransformerEncoderConfig = (
-      encoders.TransformerEncoderConfig())
-  discriminator_encoder: encoders.TransformerEncoderConfig = (
-      encoders.TransformerEncoderConfig())
+  generator_encoder: encoders.EncoderConfig = encoders.EncoderConfig()
+  discriminator_encoder: encoders.EncoderConfig = encoders.EncoderConfig()
   cls_heads: List[bert.ClsHeadConfig] = dataclasses.field(default_factory=list)
-
-
-def instantiate_classification_heads_from_cfgs(
-    cls_head_configs: List[bert.ClsHeadConfig]
-) -> List[layers.ClassificationHead]:
-  if cls_head_configs:
-    return [
-        layers.ClassificationHead(**cfg.as_dict()) for cfg in cls_head_configs
-    ]
-  else:
-    return []
-
-
-def instantiate_pretrainer_from_cfg(
-    config: ELECTRAPretrainerConfig,
-    generator_network: Optional[tf.keras.Model] = None,
-    discriminator_network: Optional[tf.keras.Model] = None,
-) -> electra_pretrainer.ElectraPretrainer:
-  """Instantiates ElectraPretrainer from the config."""
-  generator_encoder_cfg = config.generator_encoder
-  discriminator_encoder_cfg = config.discriminator_encoder
-  # Copy discriminator's embeddings to generator for easier model serialization.
-  if discriminator_network is None:
-    discriminator_network = encoders.instantiate_encoder_from_cfg(
-        discriminator_encoder_cfg)
-  if generator_network is None:
-    if config.tie_embeddings:
-      embedding_layer = discriminator_network.get_embedding_layer()
-      generator_network = encoders.instantiate_encoder_from_cfg(
-          generator_encoder_cfg, embedding_layer=embedding_layer)
-    else:
-      generator_network = encoders.instantiate_encoder_from_cfg(
-          generator_encoder_cfg)
-
-  return electra_pretrainer.ElectraPretrainer(
-      generator_network=generator_network,
-      discriminator_network=discriminator_network,
-      vocab_size=config.generator_encoder.vocab_size,
-      num_classes=config.num_classes,
-      sequence_length=config.sequence_length,
-      num_token_predictions=config.num_masked_tokens,
-      mlm_activation=tf_utils.get_activation(
-          generator_encoder_cfg.hidden_activation),
-      mlm_initializer=tf.keras.initializers.TruncatedNormal(
-          stddev=generator_encoder_cfg.initializer_range),
-      classification_heads=instantiate_classification_heads_from_cfgs(
-          config.cls_heads),
-      disallow_correct=config.disallow_correct)
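The same pattern applies to ELECTRA: both encoder fields become one-of configs and the factory moves into the task (see electra_task.py below). A hedged sketch mirroring the values used by the updated test file:

# Sketch only; sizes mirror the updated electra_task_test.py below.
from official.nlp.configs import electra
from official.nlp.configs import encoders

config = electra.ElectraPretrainerConfig(
    generator_encoder=encoders.EncoderConfig(
        bert=encoders.BertEncoderConfig(vocab_size=30522, num_layers=1)),
    discriminator_encoder=encoders.EncoderConfig(
        bert=encoders.BertEncoderConfig(vocab_size=30522, num_layers=1)),
    num_masked_tokens=20,
    sequence_length=128)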
official/nlp/configs/electra_test.py  deleted (100644 → 0)

The entire file was removed:

# Lint as: python3
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for ELECTRA configurations and models instantiation."""
import tensorflow as tf

from official.nlp.configs import bert
from official.nlp.configs import electra
from official.nlp.configs import encoders


class ELECTRAModelsTest(tf.test.TestCase):

  def test_network_invocation(self):
    config = electra.ELECTRAPretrainerConfig(
        generator_encoder=encoders.TransformerEncoderConfig(
            vocab_size=10, num_layers=1),
        discriminator_encoder=encoders.TransformerEncoderConfig(
            vocab_size=10, num_layers=2),
    )
    _ = electra.instantiate_pretrainer_from_cfg(config)

    # Invokes with classification heads.
    config = electra.ELECTRAPretrainerConfig(
        generator_encoder=encoders.TransformerEncoderConfig(
            vocab_size=10, num_layers=1),
        discriminator_encoder=encoders.TransformerEncoderConfig(
            vocab_size=10, num_layers=2),
        cls_heads=[
            bert.ClsHeadConfig(inner_dim=10, num_classes=2, name="next_sentence")
        ])
    _ = electra.instantiate_pretrainer_from_cfg(config)


if __name__ == "__main__":
  tf.test.main()
official/nlp/configs/encoders.py

@@ -15,20 +15,23 @@
 # ==============================================================================
 """Transformer Encoders.

-Includes configurations and instantiation methods.
+Includes configurations and factory methods.
 """
 from typing import Optional
+from absl import logging
 import dataclasses
+import gin
 import tensorflow as tf

+from official.modeling import hyperparams
 from official.modeling import tf_utils
-from official.modeling.hyperparams import base_config
 from official.nlp.modeling import layers
 from official.nlp.modeling import networks


 @dataclasses.dataclass
-class TransformerEncoderConfig(base_config.Config):
+class BertEncoderConfig(hyperparams.Config):
   """BERT encoder configuration."""
   vocab_size: int = 30522
   hidden_size: int = 768
@@ -44,55 +47,86 @@ class TransformerEncoderConfig(base_config.Config):
   embedding_size: Optional[int] = None


-def instantiate_encoder_from_cfg(
-    config: TransformerEncoderConfig,
-    encoder_cls=networks.TransformerEncoder,
-    embedding_layer: Optional[layers.OnDeviceEmbedding] = None):
-  """Instantiate a Transformer encoder network from TransformerEncoderConfig."""
+@dataclasses.dataclass
+class EncoderConfig(hyperparams.OneOfConfig):
+  """Encoder configuration."""
+  type: Optional[str] = "bert"
+  bert: BertEncoderConfig = BertEncoderConfig()
+
+
+ENCODER_CLS = {
+    "bert": networks.TransformerEncoder,
+}
+
+
+@gin.configurable
+def build_encoder(config: EncoderConfig,
+                  embedding_layer: Optional[layers.OnDeviceEmbedding] = None,
+                  encoder_cls=None,
+                  bypass_config: bool = False):
+  """Instantiate a Transformer encoder network from EncoderConfig.
+
+  Args:
+    config: the one-of encoder config, which provides encoder parameters of a
+      chosen encoder.
+    embedding_layer: an external embedding layer passed to the encoder.
+    encoder_cls: an external encoder cls not included in the supported
+      encoders, usually used by gin.configurable.
+    bypass_config: whether to ignore config instance to create the object with
+      `encoder_cls`.
+
+  Returns:
+    An encoder instance.
+  """
+  encoder_type = config.type
+  encoder_cfg = config.get()
+  encoder_cls = encoder_cls or ENCODER_CLS[encoder_type]
+  logging.info("Encoder class: %s to build...", encoder_cls.__name__)
+  if bypass_config:
+    return encoder_cls()
   if encoder_cls.__name__ == "EncoderScaffold":
     embedding_cfg = dict(
-        vocab_size=config.vocab_size,
-        type_vocab_size=config.type_vocab_size,
-        hidden_size=config.hidden_size,
-        max_seq_length=config.max_position_embeddings,
+        vocab_size=encoder_cfg.vocab_size,
+        type_vocab_size=encoder_cfg.type_vocab_size,
+        hidden_size=encoder_cfg.hidden_size,
+        max_seq_length=encoder_cfg.max_position_embeddings,
         initializer=tf.keras.initializers.TruncatedNormal(
-            stddev=config.initializer_range),
-        dropout_rate=config.dropout_rate,
+            stddev=encoder_cfg.initializer_range),
+        dropout_rate=encoder_cfg.dropout_rate,
     )
     hidden_cfg = dict(
-        num_attention_heads=config.num_attention_heads,
-        intermediate_size=config.intermediate_size,
+        num_attention_heads=encoder_cfg.num_attention_heads,
+        intermediate_size=encoder_cfg.intermediate_size,
         intermediate_activation=tf_utils.get_activation(
-            config.hidden_activation),
-        dropout_rate=config.dropout_rate,
-        attention_dropout_rate=config.attention_dropout_rate,
+            encoder_cfg.hidden_activation),
+        dropout_rate=encoder_cfg.dropout_rate,
+        attention_dropout_rate=encoder_cfg.attention_dropout_rate,
         kernel_initializer=tf.keras.initializers.TruncatedNormal(
-            stddev=config.initializer_range),
+            stddev=encoder_cfg.initializer_range),
     )
     kwargs = dict(
         embedding_cfg=embedding_cfg,
         hidden_cfg=hidden_cfg,
-        num_hidden_instances=config.num_layers,
-        pooled_output_dim=config.hidden_size,
+        num_hidden_instances=encoder_cfg.num_layers,
+        pooled_output_dim=encoder_cfg.hidden_size,
         pooler_layer_initializer=tf.keras.initializers.TruncatedNormal(
-            stddev=config.initializer_range))
+            stddev=encoder_cfg.initializer_range))
     return encoder_cls(**kwargs)

-  if encoder_cls.__name__ != "TransformerEncoder":
-    raise ValueError("Unknown encoder network class. %s" % str(encoder_cls))
-
-  encoder_network = encoder_cls(
-      vocab_size=config.vocab_size,
-      hidden_size=config.hidden_size,
-      num_layers=config.num_layers,
-      num_attention_heads=config.num_attention_heads,
-      intermediate_size=config.intermediate_size,
-      activation=tf_utils.get_activation(config.hidden_activation),
-      dropout_rate=config.dropout_rate,
-      attention_dropout_rate=config.attention_dropout_rate,
-      max_sequence_length=config.max_position_embeddings,
-      type_vocab_size=config.type_vocab_size,
-      initializer=tf.keras.initializers.TruncatedNormal(
-          stddev=config.initializer_range),
-      embedding_width=config.embedding_size,
-      embedding_layer=embedding_layer)
-  return encoder_network
+  # Uses the default BERTEncoder configuration schema to create the encoder.
+  # If it does not match, please add a switch branch by the encoder type.
+  return encoder_cls(
+      vocab_size=encoder_cfg.vocab_size,
+      hidden_size=encoder_cfg.hidden_size,
+      num_layers=encoder_cfg.num_layers,
+      num_attention_heads=encoder_cfg.num_attention_heads,
+      intermediate_size=encoder_cfg.intermediate_size,
+      activation=tf_utils.get_activation(encoder_cfg.hidden_activation),
+      dropout_rate=encoder_cfg.dropout_rate,
+      attention_dropout_rate=encoder_cfg.attention_dropout_rate,
+      max_sequence_length=encoder_cfg.max_position_embeddings,
+      type_vocab_size=encoder_cfg.type_vocab_size,
+      initializer=tf.keras.initializers.TruncatedNormal(
+          stddev=encoder_cfg.initializer_range),
+      embedding_width=encoder_cfg.embedding_size,
+      embedding_layer=embedding_layer)
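The key mechanism here is hyperparams.OneOfConfig: EncoderConfig.type names the active sub-config, config.get() returns it, and ENCODER_CLS maps the type string to a network class. A short usage sketch (values illustrative, following the pattern of the updated tests in this commit):

# Usage sketch for the new factory; values are illustrative.
from official.nlp.configs import encoders

config = encoders.EncoderConfig(
    bert=encoders.BertEncoderConfig(vocab_size=30522, num_layers=12))
assert config.type == "bert"       # the default; selects the active sub-config
encoder_cfg = config.get()         # returns the active BertEncoderConfig
encoder = encoders.build_encoder(config)  # -> networks.TransformerEncoder

With this design, supporting a new encoder means adding a field to EncoderConfig and an entry to ENCODER_CLS, rather than writing another instantiate_* function.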
official/nlp/modeling/models/__init__.py

@@ -14,7 +14,7 @@
 # ==============================================================================
 """Models package definition."""
 from official.nlp.modeling.models.bert_classifier import BertClassifier
-from official.nlp.modeling.models.bert_pretrainer import BertPretrainer
+from official.nlp.modeling.models.bert_pretrainer import *
 from official.nlp.modeling.models.bert_span_labeler import BertSpanLabeler
 from official.nlp.modeling.models.bert_token_classifier import BertTokenClassifier
 from official.nlp.modeling.models.electra_pretrainer import ElectraPretrainer
official/nlp/tasks/electra_task.py

@@ -19,16 +19,20 @@ import tensorflow as tf
 from official.core import base_task
 from official.core import task_factory
+from official.modeling import tf_utils
 from official.modeling.hyperparams import config_definitions as cfg
 from official.nlp.configs import bert
 from official.nlp.configs import electra
+from official.nlp.configs import encoders
 from official.nlp.data import pretrain_dataloader
+from official.nlp.modeling import layers
+from official.nlp.modeling import models


 @dataclasses.dataclass
-class ELECTRAPretrainConfig(cfg.TaskConfig):
+class ElectraPretrainConfig(cfg.TaskConfig):
   """The model config."""
-  model: electra.ELECTRAPretrainerConfig = electra.ELECTRAPretrainerConfig(
+  model: electra.ElectraPretrainerConfig = electra.ElectraPretrainerConfig(
       cls_heads=[
           bert.ClsHeadConfig(
               inner_dim=768,
@@ -40,13 +44,44 @@ class ElectraPretrainConfig(cfg.TaskConfig):
   validation_data: cfg.DataConfig = cfg.DataConfig()


-@task_factory.register_task_cls(ELECTRAPretrainConfig)
-class ELECTRAPretrainTask(base_task.Task):
+def _build_pretrainer(
+    config: electra.ElectraPretrainerConfig) -> models.ElectraPretrainer:
+  """Instantiates ElectraPretrainer from the config."""
+  generator_encoder_cfg = config.generator_encoder
+  discriminator_encoder_cfg = config.discriminator_encoder
+  # Copy discriminator's embeddings to generator for easier model serialization.
+  discriminator_network = encoders.build_encoder(discriminator_encoder_cfg)
+  if config.tie_embeddings:
+    embedding_layer = discriminator_network.get_embedding_layer()
+    generator_network = encoders.build_encoder(
+        generator_encoder_cfg, embedding_layer=embedding_layer)
+  else:
+    generator_network = encoders.build_encoder(generator_encoder_cfg)
+
+  generator_encoder_cfg = generator_encoder_cfg.get()
+  return models.ElectraPretrainer(
+      generator_network=generator_network,
+      discriminator_network=discriminator_network,
+      vocab_size=generator_encoder_cfg.vocab_size,
+      num_classes=config.num_classes,
+      sequence_length=config.sequence_length,
+      num_token_predictions=config.num_masked_tokens,
+      mlm_activation=tf_utils.get_activation(
+          generator_encoder_cfg.hidden_activation),
+      mlm_initializer=tf.keras.initializers.TruncatedNormal(
+          stddev=generator_encoder_cfg.initializer_range),
+      classification_heads=[
+          layers.ClassificationHead(**cfg.as_dict())
+          for cfg in config.cls_heads
+      ],
+      disallow_correct=config.disallow_correct)
+
+
+@task_factory.register_task_cls(ElectraPretrainConfig)
+class ElectraPretrainTask(base_task.Task):
   """ELECTRA Pretrain Task (Masked LM + Replaced Token Detection)."""

   def build_model(self):
-    return electra.instantiate_pretrainer_from_cfg(
+    return _build_pretrainer(
         self.task_config.model)

   def build_losses(self,
                    labels,
@@ -70,9 +105,7 @@ class ElectraPretrainTask(base_task.Task):
       sentence_outputs = tf.cast(
           model_outputs['sentence_outputs'], dtype=tf.float32)
       sentence_loss = tf.keras.losses.sparse_categorical_crossentropy(
           sentence_labels,
           sentence_outputs, from_logits=True)
       metrics['next_sentence_loss'].update_state(sentence_loss)
       total_loss = mlm_loss + sentence_loss
     else:
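With the factory now private to the task module, callers are expected to go through the task object. A minimal sketch, using the defaults baked into ElectraPretrainConfig:

# Sketch only; relies on the default config values shown above.
from official.nlp.tasks import electra_task

config = electra_task.ElectraPretrainConfig()
task = electra_task.ElectraPretrainTask(config)
model = task.build_model()  # internally calls _build_pretrainer(config.model)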
official/nlp/tasks/electra_task_test.py

@@ -24,15 +24,17 @@ from official.nlp.data import pretrain_dataloader
 from official.nlp.tasks import electra_task


-class ELECTRAPretrainTaskTest(tf.test.TestCase):
+class ElectraPretrainTaskTest(tf.test.TestCase):

   def test_task(self):
-    config = electra_task.ELECTRAPretrainConfig(
-        model=electra.ELECTRAPretrainerConfig(
-            generator_encoder=encoders.TransformerEncoderConfig(
-                vocab_size=30522, num_layers=1),
-            discriminator_encoder=encoders.TransformerEncoderConfig(
-                vocab_size=30522, num_layers=1),
+    config = electra_task.ElectraPretrainConfig(
+        model=electra.ElectraPretrainerConfig(
+            generator_encoder=encoders.EncoderConfig(
+                bert=encoders.BertEncoderConfig(vocab_size=30522,
+                                                num_layers=1)),
+            discriminator_encoder=encoders.EncoderConfig(
+                bert=encoders.BertEncoderConfig(vocab_size=30522,
+                                                num_layers=1)),
             num_masked_tokens=20,
             sequence_length=128,
             cls_heads=[
@@ -44,7 +46,7 @@ class ElectraPretrainTaskTest(tf.test.TestCase):
         max_predictions_per_seq=20,
         seq_length=128,
         global_batch_size=1))
-    task = electra_task.ELECTRAPretrainTask(config)
+    task = electra_task.ElectraPretrainTask(config)
     model = task.build_model()
     metrics = task.build_metrics()
     dataset = task.build_inputs(config.train_data)
official/nlp/tasks/masked_lm.py

@@ -19,15 +19,19 @@ import tensorflow as tf
 from official.core import base_task
 from official.core import task_factory
+from official.modeling import tf_utils
 from official.modeling.hyperparams import config_definitions as cfg
 from official.nlp.configs import bert
+from official.nlp.configs import encoders
 from official.nlp.data import data_loader_factory
+from official.nlp.modeling import layers
+from official.nlp.modeling import models


 @dataclasses.dataclass
 class MaskedLMConfig(cfg.TaskConfig):
   """The model config."""
-  model: bert.BertPretrainerConfig = bert.BertPretrainerConfig(cls_heads=[
+  model: bert.PretrainerConfig = bert.PretrainerConfig(cls_heads=[
       bert.ClsHeadConfig(
           inner_dim=768, num_classes=2, dropout_rate=0.1, name='next_sentence')
   ])
@@ -37,11 +41,21 @@ class MaskedLMConfig(cfg.TaskConfig):

 @task_factory.register_task_cls(MaskedLMConfig)
 class MaskedLMTask(base_task.Task):
-  """Mock task object for testing."""
+  """Task object for Mask language modeling."""

   def build_model(self, params=None):
-    params = params or self.task_config.model
-    return bert.instantiate_pretrainer_from_cfg(params)
+    config = params or self.task_config.model
+    encoder_cfg = config.encoder
+    encoder_network = encoders.build_encoder(encoder_cfg)
+    cls_heads = [
+        layers.ClassificationHead(**cfg.as_dict()) for cfg in config.cls_heads
+    ] if config.cls_heads else []
+    return models.BertPretrainerV2(
+        mlm_activation=tf_utils.get_activation(config.mlm_activation),
+        mlm_initializer=tf.keras.initializers.TruncatedNormal(
+            stddev=config.mlm_initializer_range),
+        encoder_network=encoder_network,
+        classification_heads=cls_heads)

   def build_losses(self,
                    labels,
@@ -63,9 +77,8 @@ class MaskedLMTask(base_task.Task):
       sentence_outputs = tf.cast(
           model_outputs['next_sentence'], dtype=tf.float32)
       sentence_loss = tf.reduce_mean(
-          tf.keras.losses.sparse_categorical_crossentropy(sentence_labels,
-                                                          sentence_outputs,
-                                                          from_logits=True))
+          tf.keras.losses.sparse_categorical_crossentropy(
+              sentence_labels, sentence_outputs, from_logits=True))
       metrics['next_sentence_loss'].update_state(sentence_loss)
       total_loss = mlm_loss + sentence_loss
     else:
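The loss call above is only reflowed, but it is worth noting the argument order it relies on: sparse_categorical_crossentropy takes the labels (y_true) first, then the logits, with from_logits=True. A standalone sanity check, with illustrative values:

import tensorflow as tf

labels = tf.constant([0, 1])                     # y_true: class indices
logits = tf.constant([[2.0, 0.0], [0.0, 2.0]])   # y_pred: unnormalized logits
loss = tf.reduce_mean(
    tf.keras.losses.sparse_categorical_crossentropy(
        labels, logits, from_logits=True))
print(float(loss))  # ~0.127: both predictions are confidently correct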
official/nlp/tasks/masked_lm_test.py

@@ -28,8 +28,10 @@ class MLMTaskTest(tf.test.TestCase):
   def test_task(self):
     config = masked_lm.MaskedLMConfig(
         init_checkpoint=self.get_temp_dir(),
-        model=bert.BertPretrainerConfig(
-            encoders.TransformerEncoderConfig(vocab_size=30522, num_layers=1),
+        model=bert.PretrainerConfig(
+            encoders.EncoderConfig(
+                bert=encoders.BertEncoderConfig(vocab_size=30522,
+                                                num_layers=1)),
             cls_heads=[
                 bert.ClsHeadConfig(
                     inner_dim=10, num_classes=2, name="next_sentence")
official/nlp/tasks/question_answering.py

@@ -42,8 +42,7 @@ from official.nlp.tasks import utils

 @dataclasses.dataclass
 class ModelConfig(base_config.Config):
   """A base span labeler configuration."""
-  encoder: encoders.TransformerEncoderConfig = (
-      encoders.TransformerEncoderConfig())
+  encoder: encoders.EncoderConfig = encoders.EncoderConfig()


 @dataclasses.dataclass
@@ -94,13 +93,13 @@ class QuestionAnsweringTask(base_task.Task):
     if self._hub_module:
       encoder_network = utils.get_encoder_from_hub(self._hub_module)
     else:
-      encoder_network = encoders.instantiate_encoder_from_cfg(
+      encoder_network = encoders.build_encoder(
           self.task_config.model.encoder)
+    encoder_cfg = self.task_config.model.encoder.get()
     # Currently, we only supports bert-style question answering finetuning.
     return models.BertSpanLabeler(
         network=encoder_network,
         initializer=tf.keras.initializers.TruncatedNormal(
-            stddev=self.task_config.model.encoder.initializer_range))
+            stddev=encoder_cfg.initializer_range))

   def build_losses(self, labels, model_outputs, aux_losses=None) -> tf.Tensor:
     start_positions = labels['start_positions']
official/nlp/tasks/question_answering_test.py

@@ -25,6 +25,7 @@ from official.nlp.bert import export_tfhub
 from official.nlp.configs import bert
 from official.nlp.configs import encoders
 from official.nlp.data import question_answering_dataloader
+from official.nlp.tasks import masked_lm
 from official.nlp.tasks import question_answering
@@ -32,21 +33,37 @@ class QuestionAnsweringTaskTest(tf.test.TestCase, parameterized.TestCase):

   def setUp(self):
     super(QuestionAnsweringTaskTest, self).setUp()
-    self._encoder_config = encoders.TransformerEncoderConfig(
-        vocab_size=30522, num_layers=1)
+    self._encoder_config = encoders.EncoderConfig(
+        bert=encoders.BertEncoderConfig(vocab_size=30522, num_layers=1))
     self._train_data_config = question_answering_dataloader.QADataConfig(
         input_path="dummy", seq_length=128, global_batch_size=1)

-    val_data = {"version": "1.1",
-                "data": [{"paragraphs": [
-                    {"context": "Sky is blue.",
-                     "qas": [{"question": "What is blue?", "id": "1234",
-                              "answers": [{"text": "Sky", "answer_start": 0},
-                                          {"text": "Sky", "answer_start": 0},
-                                          {"text": "Sky", "answer_start": 0}]}]}]}]}
+    val_data = {
+        "version": "1.1",
+        "data": [{
+            "paragraphs": [{
+                "context": "Sky is blue.",
+                "qas": [{
+                    "question": "What is blue?",
+                    "id": "1234",
+                    "answers": [{"text": "Sky", "answer_start": 0},
+                                {"text": "Sky", "answer_start": 0},
+                                {"text": "Sky", "answer_start": 0}]
+                }]
+            }]
+        }]
+    }
     self._val_input_path = os.path.join(self.get_temp_dir(), "val_data.json")
     with tf.io.gfile.GFile(self._val_input_path, "w") as writer:
       writer.write(json.dumps(val_data, indent=4) + "\n")
@@ -87,19 +104,20 @@ class QuestionAnsweringTaskTest(tf.test.TestCase, parameterized.TestCase):
     metrics = task.reduce_aggregated_logs(logs)
     self.assertIn("final_f1", metrics)

   @parameterized.parameters(itertools.product(
       (False, True),
       ("WordPiece", "SentencePiece"),
   ))
   def test_task(self, version_2_with_negative, tokenization):
     # Saves a checkpoint.
-    pretrain_cfg = bert.BertPretrainerConfig(
+    pretrain_cfg = bert.PretrainerConfig(
         encoder=self._encoder_config,
         cls_heads=[
             bert.ClsHeadConfig(
                 inner_dim=10, num_classes=3, name="next_sentence")
         ])
-    pretrain_model = bert.instantiate_pretrainer_from_cfg(pretrain_cfg)
+    pretrain_model = masked_lm.MaskedLMTask(None).build_model(pretrain_cfg)
     ckpt = tf.train.Checkpoint(
         model=pretrain_model, **pretrain_model.checkpoint_items)
     saved_path = ckpt.save(self.get_temp_dir())
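With bert.instantiate_pretrainer_from_cfg gone, the tests now build the pretraining checkpoint through MaskedLMTask instead. A sketch of that pattern (config values illustrative):

# Sketch of the new checkpoint-seeding pattern used by the tests.
import tensorflow as tf
from official.nlp.configs import bert
from official.nlp.configs import encoders
from official.nlp.tasks import masked_lm

pretrain_cfg = bert.PretrainerConfig(
    encoder=encoders.EncoderConfig(
        bert=encoders.BertEncoderConfig(vocab_size=30522, num_layers=1)))
# Passing None as the task config works here because build_model(params)
# prefers the explicit params over self.task_config.
pretrain_model = masked_lm.MaskedLMTask(None).build_model(pretrain_cfg)
ckpt = tf.train.Checkpoint(model=pretrain_model,
                           **pretrain_model.checkpoint_items)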
official/nlp/tasks/sentence_prediction.py

@@ -44,8 +44,7 @@ class ModelConfig(base_config.Config):
   """A classifier/regressor configuration."""
   num_classes: int = 0
   use_encoder_pooler: bool = False
-  encoder: encoders.TransformerEncoderConfig = (
-      encoders.TransformerEncoderConfig())
+  encoder: encoders.EncoderConfig = encoders.EncoderConfig()


 @dataclasses.dataclass
@@ -85,15 +84,14 @@ class SentencePredictionTask(base_task.Task):
     if self._hub_module:
       encoder_network = utils.get_encoder_from_hub(self._hub_module)
     else:
-      encoder_network = encoders.instantiate_encoder_from_cfg(
+      encoder_network = encoders.build_encoder(
           self.task_config.model.encoder)
+    encoder_cfg = self.task_config.model.encoder.get()
     # Currently, we only support bert-style sentence prediction finetuning.
     return models.BertClassifier(
         network=encoder_network,
         num_classes=self.task_config.model.num_classes,
         initializer=tf.keras.initializers.TruncatedNormal(
-            stddev=self.task_config.model.encoder.initializer_range),
+            stddev=encoder_cfg.initializer_range),
         use_encoder_pooler=self.task_config.model.use_encoder_pooler)

   def build_losses(self, labels, model_outputs, aux_losses=None) -> tf.Tensor:
official/nlp/tasks/sentence_prediction_test.py

@@ -26,6 +26,7 @@ from official.nlp.bert import export_tfhub
 from official.nlp.configs import bert
 from official.nlp.configs import encoders
 from official.nlp.data import sentence_prediction_dataloader
+from official.nlp.tasks import masked_lm
 from official.nlp.tasks import sentence_prediction
@@ -68,8 +69,8 @@ class SentencePredictionTaskTest(tf.test.TestCase, parameterized.TestCase):
   def get_model_config(self, num_classes):
     return sentence_prediction.ModelConfig(
-        encoder=encoders.TransformerEncoderConfig(
-            vocab_size=30522, num_layers=1),
+        encoder=encoders.EncoderConfig(
+            bert=encoders.BertEncoderConfig(vocab_size=30522, num_layers=1)),
         num_classes=num_classes)

   def _run_task(self, config):
@@ -102,14 +103,14 @@ class SentencePredictionTaskTest(tf.test.TestCase, parameterized.TestCase):
     task.validation_step(next(iterator), model, metrics=metrics)

     # Saves a checkpoint.
-    pretrain_cfg = bert.BertPretrainerConfig(
-        encoder=encoders.TransformerEncoderConfig(
-            vocab_size=30522, num_layers=1),
+    pretrain_cfg = bert.PretrainerConfig(
+        encoder=encoders.EncoderConfig(
+            bert=encoders.BertEncoderConfig(vocab_size=30522, num_layers=1)),
         cls_heads=[
             bert.ClsHeadConfig(
                 inner_dim=10, num_classes=3, name="next_sentence")
         ])
-    pretrain_model = bert.instantiate_pretrainer_from_cfg(pretrain_cfg)
+    pretrain_model = masked_lm.MaskedLMTask(None).build_model(pretrain_cfg)
     ckpt = tf.train.Checkpoint(
         model=pretrain_model, **pretrain_model.checkpoint_items)
     ckpt.save(config.init_checkpoint)
@@ -136,8 +137,8 @@ class SentencePredictionTaskTest(tf.test.TestCase, parameterized.TestCase):
     if num_classes == 1:
       self.assertIsInstance(metrics[0], tf.keras.metrics.MeanSquaredError)
     else:
       self.assertIsInstance(metrics[0],
                             tf.keras.metrics.SparseCategoricalAccuracy)
     dataset = task.build_inputs(config.train_data)
     iterator = iter(dataset)
official/nlp/tasks/tagging.py

@@ -37,8 +37,7 @@ from official.nlp.tasks import utils

 @dataclasses.dataclass
 class ModelConfig(base_config.Config):
   """A base span labeler configuration."""
-  encoder: encoders.TransformerEncoderConfig = (
-      encoders.TransformerEncoderConfig())
+  encoder: encoders.EncoderConfig = encoders.EncoderConfig()
   head_dropout: float = 0.1
   head_initializer_range: float = 0.02
@@ -102,8 +101,7 @@ class TaggingTask(base_task.Task):
     if self._hub_module:
       encoder_network = utils.get_encoder_from_hub(self._hub_module)
     else:
-      encoder_network = encoders.instantiate_encoder_from_cfg(
+      encoder_network = encoders.build_encoder(
           self.task_config.model.encoder)
     return models.BertTokenClassifier(
         network=encoder_network,
official/nlp/tasks/tagging_test.py

@@ -53,8 +53,8 @@ class TaggingTest(tf.test.TestCase):
   def setUp(self):
     super(TaggingTest, self).setUp()
-    self._encoder_config = encoders.TransformerEncoderConfig(
-        vocab_size=30522, num_layers=1)
+    self._encoder_config = encoders.EncoderConfig(
+        bert=encoders.BertEncoderConfig(vocab_size=30522, num_layers=1))
     self._train_data_config = tagging_data_loader.TaggingDataConfig(
         input_path="dummy", seq_length=128, global_batch_size=1)
@@ -74,7 +74,7 @@ class TaggingTest(tf.test.TestCase):
   def test_task(self):
     # Saves a checkpoint.
-    encoder = encoders.instantiate_encoder_from_cfg(self._encoder_config)
+    encoder = encoders.build_encoder(self._encoder_config)
     ckpt = tf.train.Checkpoint(encoder=encoder)
     saved_path = ckpt.save(self.get_temp_dir())