ModelZoo / ResNet50_tensorflow · Commit d48d036a

Authored Nov 11, 2019 by Chen Chen
Committed by A. Unique TensorFlower, Nov 11, 2019

Support to use KerasBERT for squad task.

PiperOrigin-RevId: 279873276

parent 146a37c6
Showing 3 changed files with 71 additions and 31 deletions (+71, -31).
official/nlp/bert/input_pipeline.py   +7  -1
official/nlp/bert/run_squad.py        +13 -5
official/nlp/bert_models.py           +51 -25
official/nlp/bert/input_pipeline.py @ d48d036a
@@ -177,7 +177,6 @@ def create_classifier_dataset(file_path,
 def create_squad_dataset(file_path, seq_length, batch_size, is_training=True):
   """Creates input dataset from (tf)records files for train/eval."""
   name_to_features = {
-      'unique_ids': tf.io.FixedLenFeature([], tf.int64),
       'input_ids': tf.io.FixedLenFeature([seq_length], tf.int64),
       'input_mask': tf.io.FixedLenFeature([seq_length], tf.int64),
       'segment_ids': tf.io.FixedLenFeature([seq_length], tf.int64),
@@ -185,15 +184,22 @@ def create_squad_dataset(file_path, seq_length, batch_size, is_training=True):
   if is_training:
     name_to_features['start_positions'] = tf.io.FixedLenFeature([], tf.int64)
     name_to_features['end_positions'] = tf.io.FixedLenFeature([], tf.int64)
+  else:
+    name_to_features['unique_ids'] = tf.io.FixedLenFeature([], tf.int64)

   input_fn = file_based_input_fn_builder(file_path, name_to_features)
   dataset = input_fn()

   def _select_data_from_record(record):
+    """Dispatches record to features and labels."""
     x, y = {}, {}
     for name, tensor in record.items():
       if name in ('start_positions', 'end_positions'):
         y[name] = tensor
+      elif name == 'input_ids':
+        x['input_word_ids'] = tensor
+      elif name == 'segment_ids':
+        x['input_type_ids'] = tensor
       else:
         x[name] = tensor
     return (x, y)
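Note (not part of the commit): a minimal runnable sketch of what the remapping
in _select_data_from_record now produces. The fake record below stands in for
one parsed tf.Example; the feature names follow the diff, while the sequence
length of 384 is made up for illustration.

import tensorflow as tf

record = {
    'input_ids': tf.zeros([384], tf.int64),
    'input_mask': tf.ones([384], tf.int64),
    'segment_ids': tf.zeros([384], tf.int64),
    'start_positions': tf.constant(17, tf.int64),
    'end_positions': tf.constant(21, tf.int64),
}

x, y = {}, {}
for name, tensor in record.items():
  if name in ('start_positions', 'end_positions'):
    y[name] = tensor
  elif name == 'input_ids':
    # Renamed to the keras-style feature name expected by the new model path.
    x['input_word_ids'] = tensor
  elif name == 'segment_ids':
    x['input_type_ids'] = tensor
  else:
    x[name] = tensor

print(sorted(x))  # ['input_mask', 'input_type_ids', 'input_word_ids']
print(sorted(y))  # ['end_positions', 'start_positions']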
official/nlp/bert/run_squad.py @ d48d036a
@@ -80,6 +80,10 @@ flags.DEFINE_integer(
     'max_answer_length', 30,
     'The maximum length of an answer that can be generated. This is needed '
     'because the start and end predictions are not conditioned on one another.')
+flags.DEFINE_bool(
+    'use_keras_bert_for_squad', False, 'Whether to use keras BERT for squad '
+    'task. Note that when the FLAG "hub_module_url" is specified, '
+    '"use_keras_bert_for_squad" cannot be True.')
 common_flags.define_common_bert_flags()
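Note (not part of the commit): the new flag is mutually exclusive with
hub_module_url, and the actual check lives in bert_models.squad_model (see the
bert_models.py diff below). A self-contained absl sketch of the interaction;
'hub_module_url' is redefined here only so the snippet runs standalone, since
the real definition comes from common_flags.

from absl import flags

flags.DEFINE_bool(
    'use_keras_bert_for_squad', False,
    'Whether to use keras BERT for squad task.')
flags.DEFINE_string('hub_module_url', None, 'TF-Hub path/url to Bert module.')

FLAGS = flags.FLAGS
FLAGS(['run_squad', '--use_keras_bert_for_squad=True'])  # parse argv
if FLAGS.hub_module_url and FLAGS.use_keras_bert_for_squad:
  raise ValueError('Cannot use hub_module_url and keras BERT at the same time.')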
@@ -108,7 +112,7 @@ def get_loss_fn(loss_factor=1.0):
   def _loss_fn(labels, model_outputs):
     start_positions = labels['start_positions']
     end_positions = labels['end_positions']
-    _, start_logits, end_logits = model_outputs
+    start_logits, end_logits = model_outputs
     return squad_loss_fn(
         start_positions,
         end_positions,
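Note (not part of the commit): the unpacking changes because squad_model no
longer threads unique_ids through as a model output. For intuition, a hedged
sketch of the kind of span loss squad_loss_fn computes; this is a common
formulation, not necessarily the repository's exact implementation.

import tensorflow as tf

def squad_loss_fn_sketch(start_positions, end_positions,
                         start_logits, end_logits, loss_factor=1.0):
  # Average of sparse cross-entropy over the start and end logits.
  start_loss = tf.keras.losses.sparse_categorical_crossentropy(
      start_positions, start_logits, from_logits=True)
  end_loss = tf.keras.losses.sparse_categorical_crossentropy(
      end_positions, end_logits, from_logits=True)
  return loss_factor * tf.reduce_mean((start_loss + end_loss) / 2.0)

# Toy shapes: batch of 2, sequence length 5.
loss = squad_loss_fn_sketch(
    tf.constant([1, 2]), tf.constant([2, 4]),
    tf.random.normal([2, 5]), tf.random.normal([2, 5]))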
@@ -147,7 +151,8 @@ def predict_squad_customized(strategy, input_meta_data, bert_config,
   # Prediction always uses float32, even if training uses mixed precision.
   tf.keras.mixed_precision.experimental.set_policy('float32')
   squad_model, _ = bert_models.squad_model(
-      bert_config, input_meta_data['max_seq_length'], float_type=tf.float32)
+      bert_config, input_meta_data['max_seq_length'], float_type=tf.float32,
+      use_keras_bert=FLAGS.use_keras_bert_for_squad)

   checkpoint_path = tf.train.latest_checkpoint(FLAGS.model_dir)
   logging.info('Restoring checkpoints from %s', checkpoint_path)
@@ -161,7 +166,8 @@ def predict_squad_customized(strategy, input_meta_data, bert_config,
   def _replicated_step(inputs):
     """Replicated prediction calculation."""
     x, _ = inputs
-    unique_ids, start_logits, end_logits = squad_model(x, training=False)
+    unique_ids = x.pop('unique_ids')
+    start_logits, end_logits = squad_model(x, training=False)
     return dict(
         unique_ids=unique_ids,
         start_logits=start_logits,
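Note (not part of the commit): unique_ids now comes straight off the feature
dict instead of round-tripping through the model. A runnable sketch with a stub
standing in for the real squad_model; shapes and values are made up.

import tensorflow as tf

def stub_squad_model(x, training=False):
  # Stand-in for bert_models.squad_model(...): returns only the two logits.
  batch = x['input_word_ids'].shape[0]
  return tf.zeros([batch, 384]), tf.zeros([batch, 384])

x = {
    'unique_ids': tf.constant([7, 8]),
    'input_word_ids': tf.zeros([2, 384], tf.int32),
    'input_mask': tf.ones([2, 384], tf.int32),
    'input_type_ids': tf.zeros([2, 384], tf.int32),
}
unique_ids = x.pop('unique_ids')  # read ids off the features, not the model
start_logits, end_logits = stub_squad_model(x, training=False)
outputs = dict(unique_ids=unique_ids,
               start_logits=start_logits,
               end_logits=end_logits)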
@@ -216,7 +222,8 @@ def train_squad(strategy,
         bert_config,
         max_seq_length,
         float_type=tf.float16 if use_float16 else tf.float32,
-        hub_module_url=FLAGS.hub_module_url)
+        hub_module_url=FLAGS.hub_module_url,
+        use_keras_bert=FLAGS.use_keras_bert_for_squad)
     squad_model.optimizer = optimization.create_optimizer(
         FLAGS.learning_rate, steps_per_epoch * epochs, warmup_steps)
     if use_float16:
@@ -340,7 +347,8 @@ def export_squad(model_export_path, input_meta_data):
   bert_config = modeling.BertConfig.from_json_file(FLAGS.bert_config_file)
   squad_model, _ = bert_models.squad_model(
-      bert_config, input_meta_data['max_seq_length'], float_type=tf.float32)
+      bert_config, input_meta_data['max_seq_length'], float_type=tf.float32,
+      use_keras_bert=FLAGS.use_keras_bert_for_squad)
   model_saving_utils.export_bert_model(
       model_export_path, model=squad_model, checkpoint_dir=FLAGS.model_dir)
official/nlp/bert_models.py @ d48d036a
@@ -26,6 +26,7 @@ from official.modeling import tf_utils
 from official.nlp import bert_modeling as modeling
 from official.nlp.modeling import networks
 from official.nlp.modeling.networks import bert_classifier
+from official.nlp.modeling.networks import bert_span_labeler


 def gather_indexes(sequence_tensor, positions):
@@ -224,6 +225,32 @@ class BertPretrainLossAndMetricLayer(tf.keras.layers.Layer):
     return final_loss


+def _get_transformer_encoder(bert_config, sequence_length):
+  """Gets a 'TransformerEncoder' object.
+
+  Args:
+    bert_config: A 'modeling.BertConfig' object.
+    sequence_length: Maximum sequence length of the training data.
+
+  Returns:
+    A networks.TransformerEncoder object.
+  """
+  return networks.TransformerEncoder(
+      vocab_size=bert_config.vocab_size,
+      hidden_size=bert_config.hidden_size,
+      num_layers=bert_config.num_hidden_layers,
+      num_attention_heads=bert_config.num_attention_heads,
+      intermediate_size=bert_config.intermediate_size,
+      activation=tf_utils.get_activation('gelu'),
+      dropout_rate=bert_config.hidden_dropout_prob,
+      attention_dropout_rate=bert_config.attention_probs_dropout_prob,
+      sequence_length=sequence_length,
+      max_sequence_length=bert_config.max_position_embeddings,
+      type_vocab_size=bert_config.type_vocab_size,
+      initializer=tf.keras.initializers.TruncatedNormal(
+          stddev=bert_config.initializer_range))
+
+
 def pretrain_model(bert_config,
                    seq_length,
                    max_predictions_per_seq,
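Note (not part of the commit): a sketch of exercising the new factory with a
deliberately tiny config. Importing the private helper is for illustration
only, and the small sizes are made up; within the repo this would run as-is
assuming the modules resolve.

import tensorflow as tf

from official.nlp import bert_modeling as modeling
from official.nlp import bert_models

small_config = modeling.BertConfig(
    vocab_size=30522,
    hidden_size=128,
    num_hidden_layers=2,
    num_attention_heads=2,
    intermediate_size=512)
# Private helper, called here only to illustrate the new code path.
encoder = bert_models._get_transformer_encoder(
    small_config, sequence_length=128)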
@@ -333,7 +360,8 @@ def squad_model(bert_config,
                 max_seq_length,
                 float_type,
                 initializer=None,
-                hub_module_url=None):
+                hub_module_url=None,
+                use_keras_bert=False):
   """Returns BERT Squad model along with core BERT model to import weights.

   Args:
@@ -342,19 +370,31 @@ def squad_model(bert_config,
     float_type: tf.dtype, tf.float32 or tf.bfloat16.
     initializer: Initializer for weights in BertSquadLogitsLayer.
     hub_module_url: TF-Hub path/url to Bert module.
+    use_keras_bert: Whether to use keras BERT. Note that when the above
+      'hub_module_url' is specified, 'use_keras_bert' cannot be True.

   Returns:
-    Two tensors, start logits and end logits, [batch x sequence length].
+    A tuple of (1) keras model that outputs start logits and end logits and
+    (2) the core BERT transformer encoder.
+
+  Raises:
+    ValueError: When 'hub_module_url' is specified and 'use_keras_bert' is True.
   """
-  unique_ids = tf.keras.layers.Input(
-      shape=(1,), dtype=tf.int32, name='unique_ids')
+  if hub_module_url and use_keras_bert:
+    raise ValueError(
+        'Cannot use hub_module_url and keras BERT at the same time.')
+
+  if use_keras_bert:
+    bert_encoder = _get_transformer_encoder(bert_config, max_seq_length)
+    return bert_span_labeler.BertSpanLabeler(
+        network=bert_encoder), bert_encoder
+
   input_word_ids = tf.keras.layers.Input(
-      shape=(max_seq_length,), dtype=tf.int32, name='input_ids')
+      shape=(max_seq_length,), dtype=tf.int32, name='input_word_ids')
   input_mask = tf.keras.layers.Input(
       shape=(max_seq_length,), dtype=tf.int32, name='input_mask')
   input_type_ids = tf.keras.layers.Input(
-      shape=(max_seq_length,), dtype=tf.int32, name='segment_ids')
+      shape=(max_seq_length,), dtype=tf.int32, name='input_type_ids')
   if hub_module_url:
     core_model = hub.KerasLayer(hub_module_url, trainable=True)
     _, sequence_output = core_model(
@@ -383,12 +423,11 @@ def squad_model(bert_config,
   squad = tf.keras.Model(
       inputs={
-          'unique_ids': unique_ids,
-          'input_ids': input_word_ids,
+          'input_word_ids': input_word_ids,
           'input_mask': input_mask,
-          'segment_ids': input_type_ids,
+          'input_type_ids': input_type_ids,
       },
-      outputs=[unique_ids, start_logits, end_logits],
+      outputs=[start_logits, end_logits],
       name='squad_model')
   return squad, core_model
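Note (not part of the commit): a hedged usage sketch of the updated builder.
With use_keras_bert=True the function returns a BertSpanLabeler plus its
encoder and never builds the functional graph above; the dict keys match the
renamed inputs. The tiny config is made up, and whether the outputs unpack
exactly like this depends on BertSpanLabeler's return structure.

import tensorflow as tf
from official.nlp import bert_modeling as modeling
from official.nlp import bert_models

config = modeling.BertConfig(
    vocab_size=30522, hidden_size=128, num_hidden_layers=2,
    num_attention_heads=2, intermediate_size=512)
squad, encoder = bert_models.squad_model(
    config, 128, float_type=tf.float32, use_keras_bert=True)
start_logits, end_logits = squad({
    'input_word_ids': tf.zeros([1, 128], tf.int32),
    'input_mask': tf.ones([1, 128], tf.int32),
    'input_type_ids': tf.zeros([1, 128], tf.int32),
})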
@@ -424,20 +463,7 @@ def classifier_model(bert_config,
       stddev=bert_config.initializer_range)
   if not hub_module_url:
-    bert_encoder = networks.TransformerEncoder(
-        vocab_size=bert_config.vocab_size,
-        hidden_size=bert_config.hidden_size,
-        num_layers=bert_config.num_hidden_layers,
-        num_attention_heads=bert_config.num_attention_heads,
-        intermediate_size=bert_config.intermediate_size,
-        activation=tf_utils.get_activation('gelu'),
-        dropout_rate=bert_config.hidden_dropout_prob,
-        attention_dropout_rate=bert_config.attention_probs_dropout_prob,
-        sequence_length=max_seq_length,
-        max_sequence_length=bert_config.max_position_embeddings,
-        type_vocab_size=bert_config.type_vocab_size,
-        initializer=tf.keras.initializers.TruncatedNormal(
-            stddev=bert_config.initializer_range))
+    bert_encoder = _get_transformer_encoder(bert_config, max_seq_length)
     return bert_classifier.BertClassifier(
         bert_encoder,
         num_classes=num_labels,