Commit 21b73d22, authored Jun 17, 2020 by A. Unique TensorFlower

Internal change

PiperOrigin-RevId: 317010998

Parent: a3263c0f

Showing 9 changed files with 412 additions and 59 deletions (+412 -59).
official/nlp/configs/bert.py                        +42 -30
official/nlp/configs/bert_test.py                    +4  -4
official/nlp/configs/encoders.py                    +26  -1
official/nlp/modeling/models/bert_span_labeler.py    +6  -0
official/nlp/tasks/masked_lm.py                      +1  -1
official/nlp/tasks/question_answering.py           +167  -0  (new file)
official/nlp/tasks/question_answering_test.py      +130  -0  (new file)
official/nlp/tasks/sentence_prediction.py            +3  -3
official/nlp/tasks/sentence_prediction_test.py      +33 -20
official/nlp/configs/bert.py

@@ -13,7 +13,10 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # ==============================================================================
-"""A multi-head BERT encoder network for pretraining."""
+"""Multi-head BERT encoder network with classification heads.
+
+Includes configurations and instantiation methods.
+"""
 from typing import List, Optional, Text

 import dataclasses
@@ -24,7 +27,6 @@ from official.modeling.hyperparams import base_config
 from official.modeling.hyperparams import config_definitions as cfg
 from official.nlp.configs import encoders
 from official.nlp.modeling import layers
-from official.nlp.modeling import networks
 from official.nlp.modeling.models import bert_pretrainer
@@ -47,43 +49,34 @@ class BertPretrainerConfig(base_config.Config):
   cls_heads: List[ClsHeadConfig] = dataclasses.field(default_factory=list)


-def instantiate_from_cfg(
+def instantiate_classification_heads_from_cfgs(
+    cls_head_configs: List[ClsHeadConfig]) -> List[layers.ClassificationHead]:
+  return [
+      layers.ClassificationHead(**cfg.as_dict()) for cfg in cls_head_configs
+  ] if cls_head_configs else []
+
+
+def instantiate_bertpretrainer_from_cfg(
     config: BertPretrainerConfig,
-    encoder_network: Optional[tf.keras.Model] = None):
+    encoder_network: Optional[tf.keras.Model] = None
+) -> bert_pretrainer.BertPretrainerV2:
   """Instantiates a BertPretrainer from the config."""
   encoder_cfg = config.encoder
   if encoder_network is None:
-    encoder_network = networks.TransformerEncoder(
-        vocab_size=encoder_cfg.vocab_size,
-        hidden_size=encoder_cfg.hidden_size,
-        num_layers=encoder_cfg.num_layers,
-        num_attention_heads=encoder_cfg.num_attention_heads,
-        intermediate_size=encoder_cfg.intermediate_size,
-        activation=tf_utils.get_activation(encoder_cfg.hidden_activation),
-        dropout_rate=encoder_cfg.dropout_rate,
-        attention_dropout_rate=encoder_cfg.attention_dropout_rate,
-        max_sequence_length=encoder_cfg.max_position_embeddings,
-        type_vocab_size=encoder_cfg.type_vocab_size,
-        initializer=tf.keras.initializers.TruncatedNormal(
-            stddev=encoder_cfg.initializer_range))
-  if config.cls_heads:
-    classification_heads = [
-        layers.ClassificationHead(**cfg.as_dict()) for cfg in config.cls_heads
-    ]
-  else:
-    classification_heads = []
+    encoder_network = encoders.instantiate_encoder_from_cfg(encoder_cfg)
   return bert_pretrainer.BertPretrainerV2(
       config.num_masked_tokens,
       mlm_activation=tf_utils.get_activation(encoder_cfg.hidden_activation),
       mlm_initializer=tf.keras.initializers.TruncatedNormal(
           stddev=encoder_cfg.initializer_range),
       encoder_network=encoder_network,
-      classification_heads=classification_heads)
+      classification_heads=instantiate_classification_heads_from_cfgs(
+          config.cls_heads))


 @dataclasses.dataclass
 class BertPretrainDataConfig(cfg.DataConfig):
-  """Data config for BERT pretraining task."""
+  """Data config for BERT pretraining task (tasks/masked_lm)."""
   input_path: str = ""
   global_batch_size: int = 512
   is_training: bool = True
@@ -95,15 +88,15 @@ class BertPretrainDataConfig(cfg.DataConfig):
 @dataclasses.dataclass
 class BertPretrainEvalDataConfig(BertPretrainDataConfig):
-  """Data config for the eval set in BERT pretraining task."""
+  """Data config for the eval set in BERT pretraining task (tasks/masked_lm)."""
   input_path: str = ""
   global_batch_size: int = 512
   is_training: bool = False


 @dataclasses.dataclass
-class BertSentencePredictionDataConfig(cfg.DataConfig):
-  """Data of sentence prediction dataset."""
+class SentencePredictionDataConfig(cfg.DataConfig):
+  """Data config for sentence prediction task (tasks/sentence_prediction)."""
   input_path: str = ""
   global_batch_size: int = 32
   is_training: bool = True
@@ -111,10 +104,29 @@ class BertSentencePredictionDataConfig(cfg.DataConfig):
 @dataclasses.dataclass
-class BertSentencePredictionDevDataConfig(cfg.DataConfig):
-  """Dev data of MNLI sentence prediction dataset."""
+class SentencePredictionDevDataConfig(cfg.DataConfig):
+  """Dev data config for sentence prediction (tasks/sentence_prediction)."""
   input_path: str = ""
   global_batch_size: int = 32
   is_training: bool = False
   seq_length: int = 128
   drop_remainder: bool = False
+
+
+@dataclasses.dataclass
+class QADataConfig(cfg.DataConfig):
+  """Data config for question answering task (tasks/question_answering)."""
+  input_path: str = ""
+  global_batch_size: int = 48
+  is_training: bool = True
+  seq_length: int = 384
+
+
+@dataclasses.dataclass
+class QADevDataConfig(cfg.DataConfig):
+  """Dev data config for question answering (tasks/question_answering)."""
+  input_path: str = ""
+  global_batch_size: int = 48
+  is_training: bool = False
+  seq_length: int = 384
+  drop_remainder: bool = False
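
The net effect of the bert.py change is a split of responsibilities: encoder construction moves to encoders.instantiate_encoder_from_cfg, head construction to instantiate_classification_heads_from_cfgs, and the pretrainer factory is renamed to match. A minimal sketch of the new call path (assuming TensorFlow and the TF Model Garden `official` package are installed; the tiny sizes are illustrative, not from this commit):

# Sketch: build a small BertPretrainerV2 through the renamed helper.
from official.nlp.configs import bert
from official.nlp.configs import encoders

config = bert.BertPretrainerConfig(
    num_masked_tokens=20,
    encoder=encoders.TransformerEncoderConfig(vocab_size=100, num_layers=1),
    cls_heads=[
        bert.ClsHeadConfig(inner_dim=10, num_classes=2, name="next_sentence")
    ])
# Encoder and classification heads are now built by the dedicated helpers.
pretrainer = bert.instantiate_bertpretrainer_from_cfg(config)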
official/nlp/configs/bert_test.py

@@ -26,7 +26,7 @@ class BertModelsTest(tf.test.TestCase):
   def test_network_invocation(self):
     config = bert.BertPretrainerConfig(
         encoder=encoders.TransformerEncoderConfig(vocab_size=10, num_layers=1))
-    _ = bert.instantiate_from_cfg(config)
+    _ = bert.instantiate_bertpretrainer_from_cfg(config)

     # Invokes with classification heads.
     config = bert.BertPretrainerConfig(
@@ -35,7 +35,7 @@ class BertModelsTest(tf.test.TestCase):
         bert.ClsHeadConfig(
             inner_dim=10, num_classes=2, name="next_sentence")
     ])
-    _ = bert.instantiate_from_cfg(config)
+    _ = bert.instantiate_bertpretrainer_from_cfg(config)

     with self.assertRaises(ValueError):
       config = bert.BertPretrainerConfig(
@@ -47,7 +47,7 @@ class BertModelsTest(tf.test.TestCase):
           bert.ClsHeadConfig(
               inner_dim=10, num_classes=2, name="next_sentence")
       ])
-      _ = bert.instantiate_from_cfg(config)
+      _ = bert.instantiate_bertpretrainer_from_cfg(config)

   def test_checkpoint_items(self):
     config = bert.BertPretrainerConfig(
@@ -56,7 +56,7 @@ class BertModelsTest(tf.test.TestCase):
         bert.ClsHeadConfig(
             inner_dim=10, num_classes=2, name="next_sentence")
     ])
-    encoder = bert.instantiate_from_cfg(config)
+    encoder = bert.instantiate_bertpretrainer_from_cfg(config)
     self.assertSameElements(encoder.checkpoint_items.keys(),
                             ["encoder", "next_sentence.pooler_dense"])
official/nlp/configs/encoders.py

@@ -13,11 +13,17 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # ==============================================================================
-"""Configurations for Encoders."""
+"""Transformer Encoders.
+
+Includes configurations and instantiation methods.
+"""
 import dataclasses
+import tensorflow as tf
+
+from official.modeling import tf_utils
 from official.modeling.hyperparams import base_config
+from official.nlp.modeling import networks


 @dataclasses.dataclass
@@ -34,3 +40,22 @@ class TransformerEncoderConfig(base_config.Config):
   max_position_embeddings: int = 512
   type_vocab_size: int = 2
   initializer_range: float = 0.02
+
+
+def instantiate_encoder_from_cfg(
+    config: TransformerEncoderConfig) -> networks.TransformerEncoder:
+  """Instantiate a Transformer encoder network from TransformerEncoderConfig."""
+  encoder_network = networks.TransformerEncoder(
+      vocab_size=config.vocab_size,
+      hidden_size=config.hidden_size,
+      num_layers=config.num_layers,
+      num_attention_heads=config.num_attention_heads,
+      intermediate_size=config.intermediate_size,
+      activation=tf_utils.get_activation(config.hidden_activation),
+      dropout_rate=config.dropout_rate,
+      attention_dropout_rate=config.attention_dropout_rate,
+      max_sequence_length=config.max_position_embeddings,
+      type_vocab_size=config.type_vocab_size,
+      initializer=tf.keras.initializers.TruncatedNormal(
+          stddev=config.initializer_range))
+  return encoder_network
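
The new instantiate_encoder_from_cfg centralizes the TransformerEncoder keyword plumbing that bert.py previously inlined, so every task builds its encoder the same way. A minimal usage sketch (illustrative sizes, not from this commit):

from official.nlp.configs import encoders

# Build a small encoder purely from config; each TransformerEncoderConfig
# field maps onto one TransformerEncoder constructor argument.
encoder_cfg = encoders.TransformerEncoderConfig(vocab_size=100, num_layers=1)
encoder = encoders.instantiate_encoder_from_cfg(encoder_cfg)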
official/nlp/modeling/models/bert_span_labeler.py

@@ -51,11 +51,13 @@ class BertSpanLabeler(tf.keras.Model):
                output='logits',
                **kwargs):
     self._self_setattr_tracking = False
+    self._network = network
     self._config = {
         'network': network,
         'initializer': initializer,
         'output': output,
     }

     # We want to use the inputs of the passed network as the inputs to this
     # Model. To do this, we need to keep a handle to the network inputs for use
     # when we construct the Model object at the end of init.
@@ -89,6 +91,10 @@ class BertSpanLabeler(tf.keras.Model):
     super(BertSpanLabeler, self).__init__(
         inputs=inputs, outputs=logits, **kwargs)

+  @property
+  def checkpoint_items(self):
+    return dict(encoder=self._network)
+
   def get_config(self):
     return self._config
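
The new checkpoint_items property mirrors BertPretrainerV2: both expose the shared encoder under the "encoder" key, which is what lets a fine-tuning task restore only the encoder weights from a pretraining checkpoint. A hedged sketch of that restore pattern (here `model` is assumed to be a built models.BertSpanLabeler and `ckpt_path` a hypothetical pretraining checkpoint prefix):

import tensorflow as tf

# Keyed restore: a pretraining checkpoint written with `encoder=...` can
# initialize just the encoder inside the span labeler.
ckpt = tf.train.Checkpoint(**model.checkpoint_items)  # {"encoder": network}
status = ckpt.restore(ckpt_path)
status.expect_partial().assert_existing_objects_matched()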
official/nlp/tasks/masked_lm.py

@@ -40,7 +40,7 @@ class MaskedLMTask(base_task.Task):
   """Mock task object for testing."""

   def build_model(self):
-    return bert.instantiate_from_cfg(self.task_config.network)
+    return bert.instantiate_bertpretrainer_from_cfg(self.task_config.network)

   def build_losses(self,
                    labels,
official/nlp/tasks/question_answering.py  (new file, mode 100644)

# Lint as: python3
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Question answering task."""
import logging
import dataclasses
import tensorflow as tf
import tensorflow_hub as hub

from official.core import base_task
from official.modeling.hyperparams import config_definitions as cfg
from official.nlp.bert import input_pipeline
from official.nlp.configs import encoders
from official.nlp.modeling import models


@dataclasses.dataclass
class QuestionAnsweringConfig(cfg.TaskConfig):
  """The model config."""
  # At most one of `init_checkpoint` and `hub_module_url` can be specified.
  init_checkpoint: str = ''
  hub_module_url: str = ''
  network: encoders.TransformerEncoderConfig = (
      encoders.TransformerEncoderConfig())
  train_data: cfg.DataConfig = cfg.DataConfig()
  validation_data: cfg.DataConfig = cfg.DataConfig()


@base_task.register_task_cls(QuestionAnsweringConfig)
class QuestionAnsweringTask(base_task.Task):
  """Task object for question answering.

  TODO(lehou): Add post-processing.
  """

  def __init__(self, params=cfg.TaskConfig):
    super(QuestionAnsweringTask, self).__init__(params)
    if params.hub_module_url and params.init_checkpoint:
      raise ValueError('At most one of `hub_module_url` and '
                       '`init_checkpoint` can be specified.')
    if params.hub_module_url:
      self._hub_module = hub.load(params.hub_module_url)
    else:
      self._hub_module = None

  def build_model(self):
    if self._hub_module:
      # TODO(lehou): maybe add the hub_module building logic to a util function.
      input_word_ids = tf.keras.layers.Input(
          shape=(None,), dtype=tf.int32, name='input_word_ids')
      input_mask = tf.keras.layers.Input(
          shape=(None,), dtype=tf.int32, name='input_mask')
      input_type_ids = tf.keras.layers.Input(
          shape=(None,), dtype=tf.int32, name='input_type_ids')
      bert_model = hub.KerasLayer(self._hub_module, trainable=True)
      pooled_output, sequence_output = bert_model(
          [input_word_ids, input_mask, input_type_ids])
      encoder_network = tf.keras.Model(
          inputs=[input_word_ids, input_mask, input_type_ids],
          outputs=[sequence_output, pooled_output])
    else:
      encoder_network = encoders.instantiate_encoder_from_cfg(
          self.task_config.network)

    return models.BertSpanLabeler(
        network=encoder_network,
        initializer=tf.keras.initializers.TruncatedNormal(
            stddev=self.task_config.network.initializer_range))

  def build_losses(self, labels, model_outputs, aux_losses=None) -> tf.Tensor:
    start_positions = labels['start_positions']
    end_positions = labels['end_positions']
    start_logits, end_logits = model_outputs

    start_loss = tf.keras.losses.sparse_categorical_crossentropy(
        start_positions,
        tf.cast(start_logits, dtype=tf.float32),
        from_logits=True)
    end_loss = tf.keras.losses.sparse_categorical_crossentropy(
        end_positions,
        tf.cast(end_logits, dtype=tf.float32),
        from_logits=True)

    loss = (tf.reduce_mean(start_loss) + tf.reduce_mean(end_loss)) / 2
    return loss

  def build_inputs(self, params, input_context=None):
    """Returns tf.data.Dataset for the question answering task."""
    if params.input_path == 'dummy':

      def dummy_data(_):
        dummy_ids = tf.zeros((1, params.seq_length), dtype=tf.int32)
        x = dict(
            input_word_ids=dummy_ids,
            input_mask=dummy_ids,
            input_type_ids=dummy_ids)
        y = dict(
            start_positions=tf.constant(0, dtype=tf.int32),
            end_positions=tf.constant(1, dtype=tf.int32))
        return (x, y)

      dataset = tf.data.Dataset.range(1)
      dataset = dataset.repeat()
      dataset = dataset.map(
          dummy_data, num_parallel_calls=tf.data.experimental.AUTOTUNE)
      return dataset

    batch_size = input_context.get_per_replica_batch_size(
        params.global_batch_size) if input_context else params.global_batch_size
    # TODO(chendouble): add and use nlp.data.question_answering_dataloader.
    dataset = input_pipeline.create_squad_dataset(
        params.input_path,
        params.seq_length,
        batch_size,
        is_training=params.is_training,
        input_pipeline_context=input_context)
    return dataset

  def build_metrics(self, training=None):
    del training
    # TODO(lehou): a list of metrics doesn't work the same as in compile/fit.
    metrics = [
        tf.keras.metrics.SparseCategoricalAccuracy(
            name='start_position_accuracy'),
        tf.keras.metrics.SparseCategoricalAccuracy(
            name='end_position_accuracy'),
    ]
    return metrics

  def process_metrics(self, metrics, labels, model_outputs):
    metrics = dict([(metric.name, metric) for metric in metrics])
    start_logits, end_logits = model_outputs
    metrics['start_position_accuracy'].update_state(
        labels['start_positions'], start_logits)
    metrics['end_position_accuracy'].update_state(
        labels['end_positions'], end_logits)

  def process_compiled_metrics(self, compiled_metrics, labels, model_outputs):
    start_logits, end_logits = model_outputs
    # labels has keys 'start_positions' and 'end_positions'.
    compiled_metrics.update_state(
        y_true=labels,
        y_pred={'start_positions': start_logits,
                'end_positions': end_logits})

  def initialize(self, model):
    """Load a pretrained checkpoint (if exists) and then train from iter 0."""
    ckpt_dir_or_file = self.task_config.init_checkpoint
    if tf.io.gfile.isdir(ckpt_dir_or_file):
      ckpt_dir_or_file = tf.train.latest_checkpoint(ckpt_dir_or_file)
    if not ckpt_dir_or_file:
      return

    ckpt = tf.train.Checkpoint(**model.checkpoint_items)
    status = ckpt.restore(ckpt_dir_or_file)
    status.expect_partial().assert_existing_objects_matched()
    logging.info('finished loading pretrained checkpoint from %s',
                 ckpt_dir_or_file)
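
Taken together, the task is fully config-driven; passing input_path="dummy" exercises the synthetic-data branch of build_inputs, which is how the tests below run without real SQuAD files. A minimal end-to-end sketch (illustrative sizes, assuming the `official` package is importable):

import tensorflow as tf

from official.nlp.configs import bert
from official.nlp.configs import encoders
from official.nlp.tasks import question_answering

# Configure a tiny encoder and the dummy data pipeline, then run one step.
config = question_answering.QuestionAnsweringConfig(
    network=encoders.TransformerEncoderConfig(vocab_size=100, num_layers=1),
    train_data=bert.QADataConfig(
        input_path="dummy", seq_length=64, global_batch_size=2))
task = question_answering.QuestionAnsweringTask(config)
model = task.build_model()
metrics = task.build_metrics()
iterator = iter(task.build_inputs(config.train_data))
optimizer = tf.keras.optimizers.SGD(lr=0.1)
task.train_step(next(iterator), model, optimizer, metrics=metrics)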
official/nlp/tasks/question_answering_test.py  (new file, mode 100644)

# Lint as: python3
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for official.nlp.tasks.question_answering."""
import functools
import os

import tensorflow as tf

from official.nlp.bert import configs
from official.nlp.bert import export_tfhub
from official.nlp.configs import bert
from official.nlp.configs import encoders
from official.nlp.tasks import question_answering


class QuestionAnsweringTaskTest(tf.test.TestCase):

  def setUp(self):
    super(QuestionAnsweringTaskTest, self).setUp()
    self._encoder_config = encoders.TransformerEncoderConfig(
        vocab_size=30522, num_layers=1)
    self._train_data_config = bert.QADataConfig(
        input_path="dummy", seq_length=128, global_batch_size=1)

  def _run_task(self, config):
    task = question_answering.QuestionAnsweringTask(config)
    model = task.build_model()
    metrics = task.build_metrics()

    strategy = tf.distribute.get_strategy()
    dataset = strategy.experimental_distribute_datasets_from_function(
        functools.partial(task.build_inputs, config.train_data))

    iterator = iter(dataset)
    optimizer = tf.keras.optimizers.SGD(lr=0.1)
    task.train_step(next(iterator), model, optimizer, metrics=metrics)
    task.validation_step(next(iterator), model, metrics=metrics)

  def test_task(self):
    # Saves a checkpoint.
    pretrain_cfg = bert.BertPretrainerConfig(
        encoder=self._encoder_config,
        num_masked_tokens=20,
        cls_heads=[
            bert.ClsHeadConfig(
                inner_dim=10, num_classes=3, name="next_sentence")
        ])
    pretrain_model = bert.instantiate_bertpretrainer_from_cfg(pretrain_cfg)
    ckpt = tf.train.Checkpoint(
        model=pretrain_model, **pretrain_model.checkpoint_items)
    saved_path = ckpt.save(self.get_temp_dir())

    config = question_answering.QuestionAnsweringConfig(
        init_checkpoint=saved_path,
        network=self._encoder_config,
        train_data=self._train_data_config)
    task = question_answering.QuestionAnsweringTask(config)
    model = task.build_model()
    metrics = task.build_metrics()
    dataset = task.build_inputs(config.train_data)

    iterator = iter(dataset)
    optimizer = tf.keras.optimizers.SGD(lr=0.1)
    task.train_step(next(iterator), model, optimizer, metrics=metrics)
    task.validation_step(next(iterator), model, metrics=metrics)
    task.initialize(model)

  def test_task_with_fit(self):
    config = question_answering.QuestionAnsweringConfig(
        network=self._encoder_config,
        train_data=self._train_data_config)
    task = question_answering.QuestionAnsweringTask(config)
    model = task.build_model()
    model = task.compile_model(
        model,
        optimizer=tf.keras.optimizers.SGD(lr=0.1),
        train_step=task.train_step,
        metrics=[tf.keras.metrics.SparseCategoricalAccuracy(name="accuracy")])
    dataset = task.build_inputs(config.train_data)
    logs = model.fit(dataset, epochs=1, steps_per_epoch=2)
    self.assertIn("loss", logs.history)
    self.assertIn("start_positions_accuracy", logs.history)
    self.assertIn("end_positions_accuracy", logs.history)

  def _export_bert_tfhub(self):
    bert_config = configs.BertConfig(
        vocab_size=30522,
        hidden_size=16,
        intermediate_size=32,
        max_position_embeddings=128,
        num_attention_heads=2,
        num_hidden_layers=1)
    _, encoder = export_tfhub.create_bert_model(bert_config)
    model_checkpoint_dir = os.path.join(self.get_temp_dir(), "checkpoint")
    checkpoint = tf.train.Checkpoint(model=encoder)
    checkpoint.save(os.path.join(model_checkpoint_dir, "test"))
    model_checkpoint_path = tf.train.latest_checkpoint(model_checkpoint_dir)

    vocab_file = os.path.join(self.get_temp_dir(), "uncased_vocab.txt")
    with tf.io.gfile.GFile(vocab_file, "w") as f:
      f.write("dummy content")

    hub_destination = os.path.join(self.get_temp_dir(), "hub")
    export_tfhub.export_bert_tfhub(bert_config, model_checkpoint_path,
                                   hub_destination, vocab_file)
    return hub_destination

  def test_task_with_hub(self):
    hub_module_url = self._export_bert_tfhub()
    config = question_answering.QuestionAnsweringConfig(
        hub_module_url=hub_module_url,
        network=self._encoder_config,
        train_data=self._train_data_config)
    self._run_task(config)


if __name__ == "__main__":
  tf.test.main()
official/nlp/tasks/sentence_prediction.py

@@ -34,7 +34,7 @@ class SentencePredictionConfig(cfg.TaskConfig):
   init_checkpoint: str = ''
   hub_module_url: str = ''
   network: bert.BertPretrainerConfig = bert.BertPretrainerConfig(
-      num_masked_tokens=0,
+      num_masked_tokens=0,  # No masked language modeling head.
       cls_heads=[
           bert.ClsHeadConfig(
               inner_dim=768,
@@ -74,10 +74,10 @@ class SentencePredictionTask(base_task.Task):
       encoder_from_hub = tf.keras.Model(
           inputs=[input_word_ids, input_mask, input_type_ids],
           outputs=[sequence_output, pooled_output])
-      return bert.instantiate_from_cfg(
+      return bert.instantiate_bertpretrainer_from_cfg(
           self.task_config.network, encoder_network=encoder_from_hub)
     else:
-      return bert.instantiate_from_cfg(self.task_config.network)
+      return bert.instantiate_bertpretrainer_from_cfg(self.task_config.network)

   def build_losses(self, labels, model_outputs, aux_losses=None) -> tf.Tensor:
     loss = loss_lib.weighted_sparse_categorical_crossentropy_loss(
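
Note that the default `network` here is a BertPretrainerConfig with num_masked_tokens=0, i.e. the pretrainer body reused as a classifier with no masked-LM head. A hedged sketch of overriding it explicitly, in the same shape the updated tests below use (illustrative sizes, not from this commit):

from official.nlp.configs import bert
from official.nlp.configs import encoders
from official.nlp.tasks import sentence_prediction

config = sentence_prediction.SentencePredictionConfig(
    network=bert.BertPretrainerConfig(
        encoder=encoders.TransformerEncoderConfig(
            vocab_size=100, num_layers=1),
        num_masked_tokens=0,  # classifier only, no masked-LM head
        cls_heads=[
            bert.ClsHeadConfig(
                inner_dim=10, num_classes=3, name="sentence_prediction")
        ]),
    train_data=bert.SentencePredictionDataConfig(
        input_path="dummy", seq_length=64, global_batch_size=2))
model = sentence_prediction.SentencePredictionTask(config).build_model()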
official/nlp/tasks/sentence_prediction_test.py

@@ -27,6 +27,19 @@ from official.nlp.tasks import sentence_prediction
 class SentencePredictionTaskTest(tf.test.TestCase):

+  def setUp(self):
+    super(SentencePredictionTaskTest, self).setUp()
+    self._network_config = bert.BertPretrainerConfig(
+        encoder=encoders.TransformerEncoderConfig(
+            vocab_size=30522, num_layers=1),
+        num_masked_tokens=0,
+        cls_heads=[
+            bert.ClsHeadConfig(
+                inner_dim=10, num_classes=3, name="sentence_prediction")
+        ])
+    self._train_data_config = bert.SentencePredictionDataConfig(
+        input_path="dummy", seq_length=128, global_batch_size=1)
+
   def _run_task(self, config):
     task = sentence_prediction.SentencePredictionTask(config)
     model = task.build_model()
@@ -44,16 +57,8 @@ class SentencePredictionTaskTest(tf.test.TestCase):
   def test_task(self):
     config = sentence_prediction.SentencePredictionConfig(
         init_checkpoint=self.get_temp_dir(),
-        network=bert.BertPretrainerConfig(
-            encoder=encoders.TransformerEncoderConfig(
-                vocab_size=30522, num_layers=1),
-            num_masked_tokens=0,
-            cls_heads=[
-                bert.ClsHeadConfig(
-                    inner_dim=10, num_classes=3, name="sentence_prediction")
-            ]),
-        train_data=bert.BertSentencePredictionDataConfig(
-            input_path="dummy", seq_length=128, global_batch_size=1))
+        network=self._network_config,
+        train_data=self._train_data_config)
     task = sentence_prediction.SentencePredictionTask(config)
     model = task.build_model()
     metrics = task.build_metrics()
@@ -73,12 +78,27 @@ class SentencePredictionTaskTest(tf.test.TestCase):
         bert.ClsHeadConfig(
             inner_dim=10, num_classes=3, name="next_sentence")
     ])
-    pretrain_model = bert.instantiate_from_cfg(pretrain_cfg)
+    pretrain_model = bert.instantiate_bertpretrainer_from_cfg(pretrain_cfg)
     ckpt = tf.train.Checkpoint(
         model=pretrain_model, **pretrain_model.checkpoint_items)
     ckpt.save(config.init_checkpoint)
     task.initialize(model)

+  def test_task_with_fit(self):
+    config = sentence_prediction.SentencePredictionConfig(
+        network=self._network_config,
+        train_data=self._train_data_config)
+    task = sentence_prediction.SentencePredictionTask(config)
+    model = task.build_model()
+    model = task.compile_model(
+        model,
+        optimizer=tf.keras.optimizers.SGD(lr=0.1),
+        train_step=task.train_step,
+        metrics=task.build_metrics())
+    dataset = task.build_inputs(config.train_data)
+    logs = model.fit(dataset, epochs=1, steps_per_epoch=2)
+    self.assertIn("loss", logs.history)
+
   def _export_bert_tfhub(self):
     bert_config = configs.BertConfig(
         vocab_size=30522,
@@ -106,15 +126,8 @@ class SentencePredictionTaskTest(tf.test.TestCase):
     hub_module_url = self._export_bert_tfhub()
     config = sentence_prediction.SentencePredictionConfig(
         hub_module_url=hub_module_url,
-        network=bert.BertPretrainerConfig(
-            encoder=encoders.TransformerEncoderConfig(
-                vocab_size=30522, num_layers=1),
-            num_masked_tokens=0,
-            cls_heads=[
-                bert.ClsHeadConfig(
-                    inner_dim=10, num_classes=3, name="sentence_prediction")
-            ]),
-        train_data=bert.BertSentencePredictionDataConfig(
-            input_path="dummy", seq_length=128, global_batch_size=10))
+        network=self._network_config,
+        train_data=self._train_data_config)
     self._run_task(config)