ModelZoo / ResNet50_tensorflow / Commits

Commit 81d031d0, authored Nov 18, 2019 by Hongkun Yu, committed by A. Unique TensorFlower on Nov 18, 2019.
Parent: c1ac2bfc

Internal change

PiperOrigin-RevId: 281117886
Changes: 6 files, with 125 additions and 91 deletions.

official/modeling/model_training_utils.py        (+4  -7)
official/modeling/model_training_utils_test.py   (+12 -9)
official/nlp/bert/input_pipeline.py              (+37 -28)
official/nlp/bert/run_classifier.py              (+32 -12)
official/nlp/bert/run_pretraining.py             (+11 -28)
official/nlp/bert/run_squad.py                   (+29 -7)
official/modeling/model_training_utils.py (+4 -7)

@@ -41,16 +41,13 @@ def _save_checkpoint(checkpoint, model_dir, checkpoint_prefix):
 def _get_input_iterator(input_fn, strategy):
   """Returns distributed dataset iterator."""
   # When training with TPU pods, datasets needs to be cloned across
   # workers. Since Dataset instance cannot be cloned in eager mode, we instead
   # pass callable that returns a dataset.
-  input_data = input_fn()
-  if callable(input_data):
-    iterator = iter(
-        strategy.experimental_distribute_datasets_from_function(input_data))
-  else:
-    iterator = iter(strategy.experimental_distribute_dataset(input_data))
+  if not callable(input_fn):
+    raise ValueError('`input_fn` should be a closure that returns a dataset.')
+  iterator = iter(
+      strategy.experimental_distribute_datasets_from_function(input_fn))
   return iterator
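For context, a minimal sketch (not part of the commit) of the calling convention the reworked _get_input_iterator now enforces: callers pass a closure that builds the dataset, so the strategy can invoke it once per worker instead of receiving a pre-built Dataset instance.

import tensorflow as tf

def input_fn(input_context=None):
  # Stand-in for the real TFRecord pipelines; the strategy passes a
  # tf.distribute.InputContext when it calls this closure.
  return tf.data.Dataset.range(64).batch(8, drop_remainder=True)

strategy = tf.distribute.MirroredStrategy()

# Pass the closure itself; a Dataset instance would now raise the new
# ValueError inside _get_input_iterator.
iterator = iter(
    strategy.experimental_distribute_datasets_from_function(input_fn))
first_batch = next(iterator)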
official/modeling/model_training_utils_test.py (+12 -9)

@@ -66,12 +66,15 @@ def create_fake_data_input_fn(batch_size, features_shape, num_classes):
     An input function that is usable in the executor.
   """

-  def _input_fn():
+  def _dataset_fn(input_context=None):
     """An input function for generating fake data."""
+    local_batch_size = input_context.get_per_replica_batch_size(batch_size)
     features = np.random.rand(64, *features_shape)
     labels = np.random.randint(2, size=[64, num_classes])
     # Convert the inputs to a Dataset.
     dataset = tf.data.Dataset.from_tensor_slices((features, labels))
+    dataset = dataset.shard(input_context.num_input_pipelines,
+                            input_context.input_pipeline_id)

     def _assign_dtype(features, labels):
       features = tf.cast(features, tf.float32)
@@ -81,11 +84,11 @@ def create_fake_data_input_fn(batch_size, features_shape, num_classes):
     # Shuffle, repeat, and batch the examples.
     dataset = dataset.map(_assign_dtype)
     dataset = dataset.shuffle(64).repeat()
-    dataset = dataset.batch(batch_size, drop_remainder=True)
+    dataset = dataset.batch(local_batch_size, drop_remainder=True)
     dataset = dataset.prefetch(buffer_size=64)
     return dataset

-  return _input_fn
+  return _dataset_fn


 def create_model_fn(input_shape, num_classes, use_float16=False):
@@ -134,21 +137,21 @@ class ModelTrainingUtilsTest(tf.test.TestCase, parameterized.TestCase):
   def setUp(self):
     super(ModelTrainingUtilsTest, self).setUp()
-    self._input_fn = create_fake_data_input_fn(
-        batch_size=8, features_shape=[128], num_classes=3)
     self._model_fn = create_model_fn(input_shape=[128], num_classes=3)

-  def run_training(self, distribution, model_dir, steps_per_loop, run_eagerly):
+  def run_training(self, strategy, model_dir, steps_per_loop, run_eagerly):
+    input_fn = create_fake_data_input_fn(
+        batch_size=8, features_shape=[128], num_classes=3)
     model_training_utils.run_customized_training_loop(
-        strategy=distribution,
+        strategy=strategy,
         model_fn=self._model_fn,
         loss_fn=tf.keras.losses.categorical_crossentropy,
         model_dir=model_dir,
         steps_per_epoch=20,
         steps_per_loop=steps_per_loop,
         epochs=2,
-        train_input_fn=self._input_fn,
-        eval_input_fn=self._input_fn,
+        train_input_fn=input_fn,
+        eval_input_fn=input_fn,
         eval_steps=10,
         init_checkpoint=None,
         metric_fn=metric_fn,
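As a rough illustration (values and names are assumptions, not the test's exact code), the pattern the updated test fixture follows: derive the per-replica batch size from the InputContext and shard the fake data across input pipelines.

import numpy as np
import tensorflow as tf

def make_dataset_fn(global_batch_size, features_shape, num_classes):
  def _dataset_fn(input_context):
    # Each replica batches its share of the global batch size.
    local_batch_size = input_context.get_per_replica_batch_size(
        global_batch_size)
    features = np.random.rand(64, *features_shape).astype(np.float32)
    labels = np.random.randint(2, size=[64, num_classes]).astype(np.int32)
    dataset = tf.data.Dataset.from_tensor_slices((features, labels))
    # Each input pipeline (host) reads a disjoint slice of the data.
    dataset = dataset.shard(input_context.num_input_pipelines,
                            input_context.input_pipeline_id)
    return dataset.shuffle(64).repeat().batch(
        local_batch_size, drop_remainder=True)
  return _dataset_fn

strategy = tf.distribute.MirroredStrategy()
dist_dataset = strategy.experimental_distribute_datasets_from_function(
    make_dataset_fn(global_batch_size=8, features_shape=[128], num_classes=3))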
official/nlp/bert/input_pipeline.py (+37 -28)

@@ -36,27 +36,22 @@ def decode_record(record, name_to_features):
   return example


-def file_based_input_fn_builder(input_file, name_to_features):
-  """Creates an `input_fn` closure to be passed for BERT custom training."""
-
-  def input_fn():
-    """Returns dataset for training/evaluation."""
+def single_file_dataset(input_file, name_to_features):
+  """Creates a single-file dataset to be passed for BERT custom training."""
   # For training, we want a lot of parallel reading and shuffling.
   # For eval, we want no shuffling and parallel reading doesn't matter.
   d = tf.data.TFRecordDataset(input_file)
   d = d.map(lambda record: decode_record(record, name_to_features))

   # When `input_file` is a path to a single file or a list
   # containing a single path, disable auto sharding so that
   # same input file is sent to all workers.
   if isinstance(input_file, str) or len(input_file) == 1:
     options = tf.data.Options()
     options.experimental_distribute.auto_shard_policy = (
         tf.data.experimental.AutoShardPolicy.OFF)
     d = d.with_options(options)
   return d

-  return input_fn


 def create_pretrain_dataset(input_patterns,
@@ -142,7 +137,7 @@ def create_classifier_dataset(file_path,
                               seq_length,
                               batch_size,
                               is_training=True,
-                              drop_remainder=True):
+                              input_pipeline_context=None):
   """Creates input dataset from (tf)records files for train/eval."""
   name_to_features = {
       'input_ids': tf.io.FixedLenFeature([seq_length], tf.int64),
@@ -151,8 +146,13 @@ def create_classifier_dataset(file_path,
       'label_ids': tf.io.FixedLenFeature([], tf.int64),
       'is_real_example': tf.io.FixedLenFeature([], tf.int64),
   }
-  input_fn = file_based_input_fn_builder(file_path, name_to_features)
-  dataset = input_fn()
+  dataset = single_file_dataset(file_path, name_to_features)
+
+  # The dataset is always sharded by number of hosts.
+  # num_input_pipelines is the number of hosts rather than number of cores.
+  if input_pipeline_context and input_pipeline_context.num_input_pipelines > 1:
+    dataset = dataset.shard(input_pipeline_context.num_input_pipelines,
+                            input_pipeline_context.input_pipeline_id)

   def _select_data_from_record(record):
     x = {
@@ -169,12 +169,16 @@ def create_classifier_dataset(file_path,
     dataset = dataset.shuffle(100)
     dataset = dataset.repeat()
-  dataset = dataset.batch(batch_size, drop_remainder=drop_remainder)
+  dataset = dataset.batch(batch_size, drop_remainder=is_training)
   dataset = dataset.prefetch(1024)
   return dataset


-def create_squad_dataset(file_path, seq_length, batch_size, is_training=True):
+def create_squad_dataset(file_path,
+                         seq_length,
+                         batch_size,
+                         is_training=True,
+                         input_pipeline_context=None):
   """Creates input dataset from (tf)records files for train/eval."""
   name_to_features = {
       'input_ids': tf.io.FixedLenFeature([seq_length], tf.int64),
@@ -187,8 +191,13 @@ def create_squad_dataset(file_path, seq_length, batch_size, is_training=True):
   else:
     name_to_features['unique_ids'] = tf.io.FixedLenFeature([], tf.int64)

-  input_fn = file_based_input_fn_builder(file_path, name_to_features)
-  dataset = input_fn()
+  dataset = single_file_dataset(file_path, name_to_features)
+
+  # The dataset is always sharded by number of hosts.
+  # num_input_pipelines is the number of hosts rather than number of cores.
+  if input_pipeline_context and input_pipeline_context.num_input_pipelines > 1:
+    dataset = dataset.shard(input_pipeline_context.num_input_pipelines,
+                            input_pipeline_context.input_pipeline_id)

   def _select_data_from_record(record):
     """Dispatches record to features and labels."""
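A standalone sketch (the file path and batch size are hypothetical) of the idea behind single_file_dataset plus the new input_pipeline_context argument: with a single input file, tf.data auto-sharding is switched off so every worker reads the whole file, and the dataset is then sharded explicitly by host.

import tensorflow as tf

def dataset_fn(ctx):
  # ctx is the tf.distribute.InputContext the strategy supplies.
  d = tf.data.TFRecordDataset('/tmp/train.tf_record')  # hypothetical path
  # Single input file: disable auto-sharding so the same file is sent to
  # all workers ...
  options = tf.data.Options()
  options.experimental_distribute.auto_shard_policy = (
      tf.data.experimental.AutoShardPolicy.OFF)
  d = d.with_options(options)
  # ... then shard by host; num_input_pipelines counts hosts, not cores.
  if ctx and ctx.num_input_pipelines > 1:
    d = d.shard(ctx.num_input_pipelines, ctx.input_pipeline_id)
  return d.batch(ctx.get_per_replica_batch_size(32), drop_remainder=True)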
official/nlp/bert/run_classifier.py (+32 -12)

@@ -18,7 +18,6 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function

-import functools
 import json
 import math
 import os
@@ -80,6 +79,25 @@ def get_loss_fn(num_classes, loss_factor=1.0):
   return classification_loss_fn


+def get_dataset_fn(input_file_pattern, max_seq_length, global_batch_size,
+                   is_training):
+  """Gets a closure to create a dataset."""
+
+  def _dataset_fn(ctx=None):
+    """Returns tf.data.Dataset for distributed BERT pretraining."""
+    batch_size = ctx.get_per_replica_batch_size(
+        global_batch_size) if ctx else global_batch_size
+    dataset = input_pipeline.create_classifier_dataset(
+        input_file_pattern,
+        max_seq_length,
+        batch_size,
+        is_training=is_training,
+        input_pipeline_context=ctx)
+    return dataset
+
+  return _dataset_fn
+
+
 def run_bert_classifier(strategy,
                         bert_config,
                         input_meta_data,
@@ -264,7 +282,10 @@ def export_classifier(model_export_path, input_meta_data,
       restore_model_using_load_weights=restore_model_using_load_weights)


-def run_bert(strategy, input_meta_data, train_input_fn, eval_input_fn):
+def run_bert(strategy,
+             input_meta_data,
+             train_input_fn=None,
+             eval_input_fn=None):
   """Run BERT training."""
   bert_config = modeling.BertConfig.from_json_file(FLAGS.bert_config_file)
   if FLAGS.mode == 'export_only':
@@ -340,18 +361,17 @@ def main(_):
                                               FLAGS.strategy_type)
   max_seq_length = input_meta_data['max_seq_length']
-  train_input_fn = functools.partial(
-      input_pipeline.create_classifier_dataset,
+  train_input_fn = get_dataset_fn(
       FLAGS.train_data_path,
-      seq_length=max_seq_length,
-      batch_size=FLAGS.train_batch_size)
-  eval_input_fn = functools.partial(
-      input_pipeline.create_classifier_dataset,
+      max_seq_length,
+      FLAGS.train_batch_size,
+      is_training=True)
+  eval_input_fn = get_dataset_fn(
       FLAGS.eval_data_path,
-      seq_length=max_seq_length,
-      batch_size=FLAGS.eval_batch_size,
-      is_training=False,
-      drop_remainder=False)
+      max_seq_length,
+      FLAGS.eval_batch_size,
+      is_training=False)

   run_bert(strategy, input_meta_data, train_input_fn, eval_input_fn)
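A simplified, self-contained sketch (illustrative names and paths, not the real flags or module functions) of the closure factory that replaces the functools.partial calls: the per-replica batch size is resolved from the InputContext when one is supplied, and falls back to the global batch size otherwise.

import tensorflow as tf

def get_dataset_fn(file_path, global_batch_size, is_training):
  """Returns a closure the strategy can call with an InputContext."""
  def _dataset_fn(ctx=None):
    # Use the per-replica batch size under a strategy; otherwise fall back
    # to the global batch size (e.g. when the closure is called directly).
    batch_size = (ctx.get_per_replica_batch_size(global_batch_size)
                  if ctx else global_batch_size)
    dataset = tf.data.TFRecordDataset(file_path)
    if is_training:
      dataset = dataset.shuffle(100).repeat()
    return dataset.batch(batch_size, drop_remainder=is_training)
  return _dataset_fn

# The closures, not materialized datasets, are what run_bert now receives.
train_input_fn = get_dataset_fn('/tmp/train.tf_record', 32, is_training=True)
eval_input_fn = get_dataset_fn('/tmp/eval.tf_record', 32, is_training=False)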
official/nlp/bert/run_pretraining.py (+11 -28)

@@ -13,13 +13,10 @@
 # limitations under the License.
 # ==============================================================================
 """Run masked LM/next sentence masked_lm pre-training for BERT in tf2.0."""
 from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function

-import functools
-
 from absl import app
 from absl import flags
 from absl import logging
@@ -56,31 +53,17 @@ common_flags.define_common_bert_flags()
 FLAGS = flags.FLAGS


-def get_pretrain_input_data(input_file_pattern, seq_length,
-                            max_predictions_per_seq, batch_size, strategy):
+def get_pretrain_dataset_fn(input_file_pattern, seq_length,
+                            max_predictions_per_seq, global_batch_size):
   """Returns input dataset from input file string."""
-
-  # When using TPU pods, we need to clone dataset across
-  # workers and need to pass in function that returns the dataset rather
-  # than passing dataset instance itself.
-  use_dataset_fn = isinstance(strategy, tf.distribute.experimental.TPUStrategy)
-  if use_dataset_fn:
-    if batch_size % strategy.num_replicas_in_sync != 0:
-      raise ValueError(
-          'Batch size must be divisible by number of replicas : {}'.format(
-              strategy.num_replicas_in_sync))
-
-    # As auto rebatching is not supported in
-    # `experimental_distribute_datasets_from_function()` API, which is
-    # required when cloning dataset to multiple workers in eager mode,
-    # we use per-replica batch size.
-    batch_size = int(batch_size / strategy.num_replicas_in_sync)
-
   def _dataset_fn(ctx=None):
     """Returns tf.data.Dataset for distributed BERT pretraining."""
-    input_patterns = input_file_pattern.split(',')
+    input_files = []
+    for input_pattern in input_file_pattern.split(','):
+      input_files.extend(tf.io.gfile.glob(input_pattern))
+    batch_size = ctx.get_per_replica_batch_size(global_batch_size)
     train_dataset = input_pipeline.create_pretrain_dataset(
-        input_patterns,
+        input_files,
         seq_length,
         max_predictions_per_seq,
         batch_size,
@@ -88,7 +71,7 @@ def get_pretrain_input_data(input_file_pattern, seq_length,
         input_pipeline_context=ctx)
     return train_dataset

-  return _dataset_fn if use_dataset_fn else _dataset_fn()
+  return _dataset_fn


 def get_loss_fn(loss_factor=1.0):
@@ -114,9 +97,9 @@ def run_customized_training(strategy,
                             train_batch_size):
   """Run BERT pretrain model training using low-level API."""

-  train_input_fn = functools.partial(get_pretrain_input_data, input_files,
-                                     max_seq_length, max_predictions_per_seq,
-                                     train_batch_size, strategy)
+  train_input_fn = get_pretrain_dataset_fn(input_files, max_seq_length,
+                                           max_predictions_per_seq,
+                                           train_batch_size)

   def _get_pretrain_model():
     """Gets a pretraining model."""
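A hedged sketch of the pretraining change (the factory below is a cut-down illustration; the real one also takes the sequence-length flags): the TPU-specific branching and manual batch-size division are gone, the file-pattern glob happens inside the closure so each worker expands it when the strategy calls the function, and the per-replica batch size comes from the InputContext.

import tensorflow as tf

def get_pretrain_dataset_fn(input_file_pattern, global_batch_size):
  def _dataset_fn(ctx):
    # Expand the comma-separated glob patterns per worker.
    input_files = []
    for pattern in input_file_pattern.split(','):
      input_files.extend(tf.io.gfile.glob(pattern))
    # InputContext replaces the old manual division by num_replicas_in_sync.
    batch_size = ctx.get_per_replica_batch_size(global_batch_size)
    dataset = tf.data.TFRecordDataset(input_files)
    return dataset.repeat().batch(batch_size, drop_remainder=True)
  return _dataset_fn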
official/nlp/bert/run_squad.py (+29 -7)

@@ -18,7 +18,6 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function

-import functools
 import json
 import os
@@ -136,22 +135,44 @@ def get_raw_results(predictions):
         end_logits=values[2].tolist())


+def get_dataset_fn(input_file_pattern, max_seq_length, global_batch_size,
+                   is_training):
+  """Gets a closure to create a dataset.."""
+
+  def _dataset_fn(ctx=None):
+    """Returns tf.data.Dataset for distributed BERT pretraining."""
+    batch_size = ctx.get_per_replica_batch_size(
+        global_batch_size) if ctx else global_batch_size
+    dataset = input_pipeline.create_squad_dataset(
+        input_file_pattern,
+        max_seq_length,
+        batch_size,
+        is_training=is_training,
+        input_pipeline_context=ctx)
+    return dataset
+
+  return _dataset_fn
+
+
 def predict_squad_customized(strategy, input_meta_data, bert_config,
                              predict_tfrecord_path, num_steps):
   """Make predictions using a Bert-based squad model."""
-  predict_dataset = input_pipeline.create_squad_dataset(
+  predict_dataset_fn = get_dataset_fn(
       predict_tfrecord_path,
       input_meta_data['max_seq_length'],
       FLAGS.predict_batch_size,
       is_training=False)
   predict_iterator = iter(
-      strategy.experimental_distribute_dataset(predict_dataset))
+      strategy.experimental_distribute_datasets_from_function(
+          predict_dataset_fn))

   with strategy.scope():
     # Prediction always uses float32, even if training uses mixed precision.
     tf.keras.mixed_precision.experimental.set_policy('float32')
     squad_model, _ = bert_models.squad_model(
         bert_config, input_meta_data['max_seq_length'], float_type=tf.float32,
         use_keras_bert=FLAGS.use_keras_bert_for_squad)

   checkpoint_path = tf.train.latest_checkpoint(FLAGS.model_dir)
@@ -208,8 +229,7 @@ def train_squad(strategy,
   max_seq_length = input_meta_data['max_seq_length']
   steps_per_epoch = int(num_train_examples / FLAGS.train_batch_size)
   warmup_steps = int(epochs * num_train_examples * 0.1 / FLAGS.train_batch_size)
-  train_input_fn = functools.partial(
-      input_pipeline.create_squad_dataset,
+  train_input_fn = get_dataset_fn(
       FLAGS.train_data_path,
       max_seq_length,
       FLAGS.train_batch_size,
@@ -347,7 +367,9 @@ def export_squad(model_export_path, input_meta_data):
   bert_config = modeling.BertConfig.from_json_file(FLAGS.bert_config_file)
   squad_model, _ = bert_models.squad_model(
       bert_config, input_meta_data['max_seq_length'], float_type=tf.float32,
       use_keras_bert=FLAGS.use_keras_bert_for_squad)
   model_saving_utils.export_bert_model(
       model_export_path, model=squad_model, checkpoint_dir=FLAGS.model_dir)
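Finally, an approximate sketch (hypothetical path, batch size, and step count) of the prediction-side switch: the SQuAD predictor now distributes a dataset function rather than a dataset instance, then steps the resulting iterator a fixed number of times, as predict_squad_customized does with num_steps.

import tensorflow as tf

def predict_dataset_fn(ctx):
  batch_size = ctx.get_per_replica_batch_size(8)
  dataset = tf.data.TFRecordDataset('/tmp/predict.tf_record')  # hypothetical
  return dataset.batch(batch_size, drop_remainder=False)

strategy = tf.distribute.MirroredStrategy()
predict_iterator = iter(
    strategy.experimental_distribute_datasets_from_function(
        predict_dataset_fn))

num_steps = 10
for _ in range(num_steps):
  distributed_batch = next(predict_iterator)  # one distributed batch per step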