dcuai / dlexamples · Commits · cb8dde1c

Commit cb8dde1c, authored Jul 14, 2022 by hepj: Add transformer-xl model code
Parent: a22e7ca7
Showing 11 changed files with 5840 additions and 0 deletions (+5840, -0)
TensorFlow/NLP/transformer-xl-master/sota/enwik8.sh      +58    -0
TensorFlow/NLP/transformer-xl-master/sota/lm1b.sh        +63    -0
TensorFlow/NLP/transformer-xl-master/sota/text8.sh       +58    -0
TensorFlow/NLP/transformer-xl-master/sota/wt103.sh       +71    -0
TensorFlow/NLP/transformer-xl-master/test.sh             +3     -0
TensorFlow/NLP/transformer-xl-master/tpu_estimator.py    +3519  -0
TensorFlow/NLP/transformer-xl-master/train.py            +462   -0
TensorFlow/NLP/transformer-xl-master/train_fp16.py       +480   -0
TensorFlow/NLP/transformer-xl-master/train_gpu.py        +480   -0
TensorFlow/NLP/transformer-xl-master/train_gpu.py_old    +476   -0
TensorFlow/NLP/transformer-xl-master/vocabulary.py       +170   -0
TensorFlow/NLP/transformer-xl-master/sota/enwik8.sh (new file, mode 100644)

#!/bin/bash

# Data
DATA_ROOT=./
DATA_DIR=${DATA_ROOT}/pretrained_xl/tf_enwik8/data
MODEL_DIR=${DATA_ROOT}/pretrained_xl/tf_enwik8/model

# Model
N_LAYER=24
D_MODEL=1024
D_EMBED=1024
N_HEAD=8
D_HEAD=128
D_INNER=3072

# Testing
TEST_TGT_LEN=128
TEST_MEM_LEN=3800
TEST_CLAMP_LEN=1000

TEST_CKPT_PATH=${MODEL_DIR}/model.ckpt-0
TEST_BSZ=16
TEST_NUM_CORE=2

echo 'Preprocess test set...'
python data_utils.py \
    --data_dir=${DATA_DIR}/ \
    --dataset=enwik8 \
    --tgt_len=${TEST_TGT_LEN} \
    --per_host_test_bsz=${TEST_BSZ} \
    --num_passes=1 \
    --use_tpu=False

echo 'Run evaluation on test set...'
python train_gpu.py \
    --data_dir=${DATA_DIR}/tfrecords \
    --record_info_dir=${DATA_DIR}/tfrecords/ \
    --corpus_info_path=${DATA_DIR}/corpus-info.json \
    --eval_ckpt_path=${TEST_CKPT_PATH} \
    --model_dir=EXP-enwik8 \
    --n_layer=${N_LAYER} \
    --d_model=${D_MODEL} \
    --d_embed=${D_EMBED} \
    --n_head=${N_HEAD} \
    --d_head=${D_HEAD} \
    --d_inner=${D_INNER} \
    --dropout=0.0 \
    --dropatt=0.0 \
    --tgt_len=${TEST_TGT_LEN} \
    --mem_len=${TEST_MEM_LEN} \
    --clamp_len=${TEST_CLAMP_LEN} \
    --same_length=True \
    --eval_batch_size=${TEST_BSZ} \
    --num_core_per_host=${TEST_NUM_CORE} \
    --do_train=False \
    --do_eval=True \
    --eval_split=test
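The script is just two flag-driven Python invocations: preprocess the test set, then run evaluation against the pretrained checkpoint. As a rough editor's sketch (not part of the commit), the same flow could be driven from Python with subprocess; paths and flag values mirror enwik8.sh above, and the helper itself is hypothetical.

# Editor's sketch: drive the enwik8 evaluation from Python instead of bash.
# Assumes data_utils.py / train_gpu.py live in the current directory, as in
# the shell script above.
import subprocess

DATA_DIR = "./pretrained_xl/tf_enwik8/data"
MODEL_DIR = "./pretrained_xl/tf_enwik8/model"

def run(script, **flags):
    """Run a flag-driven script the same way the shell wrapper does."""
    cmd = ["python", script] + ["--%s=%s" % (k, v) for k, v in flags.items()]
    subprocess.check_call(cmd)

run("data_utils.py", data_dir=DATA_DIR + "/", dataset="enwik8",
    tgt_len=128, per_host_test_bsz=16, num_passes=1, use_tpu=False)
run("train_gpu.py", data_dir=DATA_DIR + "/tfrecords",
    record_info_dir=DATA_DIR + "/tfrecords/",
    corpus_info_path=DATA_DIR + "/corpus-info.json",
    eval_ckpt_path=MODEL_DIR + "/model.ckpt-0", model_dir="EXP-enwik8",
    n_layer=24, d_model=1024, d_embed=1024, n_head=8, d_head=128,
    d_inner=3072, dropout=0.0, dropatt=0.0, tgt_len=128, mem_len=3800,
    clamp_len=1000, same_length=True, eval_batch_size=16,
    num_core_per_host=2, do_train=False, do_eval=True, eval_split="test")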
TensorFlow/NLP/transformer-xl-master/sota/lm1b.sh (new file, mode 100644)

#!/bin/bash

# Data
DATA_ROOT=./
DATA_DIR=${DATA_ROOT}/pretrained_xl/tf_lm1b/data
MODEL_DIR=${DATA_ROOT}/pretrained_xl/tf_lm1b/model

# Model
DIV_VAL=4
N_LAYER=24
D_MODEL=1280
D_EMBED=1280
N_HEAD=16
D_HEAD=80
D_INNER=8192

# Testing
TEST_TGT_LEN=32
TEST_MEM_LEN=128
TEST_CLAMP_LEN=-1

TEST_CKPT_PATH=${MODEL_DIR}/model.ckpt-1191000
TEST_BSZ=16
TEST_NUM_CORE=1

echo 'Preprocess test set...'
python data_utils.py \
    --data_dir=${DATA_DIR}/ \
    --dataset=lm1b \
    --tgt_len=${TEST_TGT_LEN} \
    --per_host_test_bsz=${TEST_BSZ} \
    --num_passes=1 \
    --use_tpu=False

echo 'Run evaluation on test set...'
python train_gpu.py \
    --data_dir=${DATA_DIR}/tfrecords \
    --record_info_dir=${DATA_DIR}/tfrecords/ \
    --corpus_info_path=${DATA_DIR}/corpus-info.json \
    --eval_ckpt_path=${TEST_CKPT_PATH} \
    --model_dir=EXP-lm1b \
    --div_val=${DIV_VAL} \
    --untie_r=True \
    --proj_share_all_but_first=False \
    --proj_same_dim=False \
    --n_layer=${N_LAYER} \
    --d_model=${D_MODEL} \
    --d_embed=${D_EMBED} \
    --n_head=${N_HEAD} \
    --d_head=${D_HEAD} \
    --d_inner=${D_INNER} \
    --dropout=0.0 \
    --dropatt=0.0 \
    --tgt_len=${TEST_TGT_LEN} \
    --mem_len=${TEST_MEM_LEN} \
    --clamp_len=${TEST_CLAMP_LEN} \
    --same_length=True \
    --eval_batch_size=${TEST_BSZ} \
    --num_core_per_host=${TEST_NUM_CORE} \
    --do_train=False \
    --do_eval=True \
    --eval_split=test
TensorFlow/NLP/transformer-xl-master/sota/text8.sh (new file, mode 100644)

#!/bin/bash

# Data
DATA_ROOT=./
DATA_DIR=${DATA_ROOT}/pretrained_xl/tf_text8/data
MODEL_DIR=${DATA_ROOT}/pretrained_xl/tf_text8/model

# Model
N_LAYER=24
D_MODEL=1024
D_EMBED=1024
N_HEAD=8
D_HEAD=128
D_INNER=3072

# Testing
TEST_TGT_LEN=128
TEST_MEM_LEN=3800
TEST_CLAMP_LEN=1000

TEST_CKPT_PATH=${MODEL_DIR}/model.ckpt-0
TEST_BSZ=16
TEST_NUM_CORE=2

echo 'Preprocess test set...'
python data_utils.py \
    --data_dir=${DATA_DIR}/ \
    --dataset=text8 \
    --tgt_len=${TEST_TGT_LEN} \
    --per_host_test_bsz=${TEST_BSZ} \
    --num_passes=1 \
    --use_tpu=False

echo 'Run evaluation on test set...'
python train_gpu.py \
    --data_dir=${DATA_DIR}/tfrecords \
    --record_info_dir=${DATA_DIR}/tfrecords/ \
    --corpus_info_path=${DATA_DIR}/corpus-info.json \
    --eval_ckpt_path=${TEST_CKPT_PATH} \
    --model_dir=EXP-text8 \
    --n_layer=${N_LAYER} \
    --d_model=${D_MODEL} \
    --d_embed=${D_EMBED} \
    --n_head=${N_HEAD} \
    --d_head=${D_HEAD} \
    --d_inner=${D_INNER} \
    --dropout=0.0 \
    --dropatt=0.0 \
    --tgt_len=${TEST_TGT_LEN} \
    --mem_len=${TEST_MEM_LEN} \
    --clamp_len=${TEST_CLAMP_LEN} \
    --same_length=True \
    --eval_batch_size=${TEST_BSZ} \
    --num_core_per_host=${TEST_NUM_CORE} \
    --do_train=False \
    --do_eval=True \
    --eval_split=test
TensorFlow/NLP/transformer-xl-master/sota/wt103.sh (new file, mode 100644)

#!/bin/bash

# Data
DATA_ROOT=./
DATA_DIR=${DATA_ROOT}/pretrained_xl/tf_wt103/data
MODEL_DIR=${DATA_ROOT}/pretrained_xl/tf_wt103/model

# Model
DIV_VAL=4
N_LAYER=18
D_MODEL=1024
D_EMBED=1024
N_HEAD=16
D_HEAD=64
D_INNER=4096

# Training
TGT_LEN=256
MEM_LEN=256
BSZ=16
NUM_CORE=2

# Testing
TEST_TGT_LEN=128
TEST_MEM_LEN=1600
TEST_CLAMP_LEN=1000

TEST_CKPT_PATH=${MODEL_DIR}/model.ckpt-0
TEST_BSZ=16
TEST_NUM_CORE=1

echo 'Preprocess test set...'
python data_utils.py \
    --data_dir=${DATA_DIR}/ \
    --dataset=wt103 \
    --tgt_len=${TEST_TGT_LEN} \
    --per_host_test_bsz=${TEST_BSZ} \
    --num_passes=1 \
    --use_tpu=False

echo 'Run evaluation on test set...'
python train_gpu.py \
    --data_dir=${DATA_DIR}/tfrecords \
    --record_info_dir=${DATA_DIR}/tfrecords/ \
    --corpus_info_path=${DATA_DIR}/corpus-info.json \
    --eval_ckpt_path=${TEST_CKPT_PATH} \
    --model_dir=EXP-wt103 \
    --div_val=${DIV_VAL} \
    --untie_r=True \
    --proj_share_all_but_first=True \
    --n_layer=${N_LAYER} \
    --d_model=${D_MODEL} \
    --d_embed=${D_EMBED} \
    --n_head=${N_HEAD} \
    --d_head=${D_HEAD} \
    --d_inner=${D_INNER} \
    --dropout=0.0 \
    --dropatt=0.0 \
    --tgt_len=${TEST_TGT_LEN} \
    --mem_len=${TEST_MEM_LEN} \
    --clamp_len=${TEST_CLAMP_LEN} \
    --same_length=True \
    --eval_batch_size=${TEST_BSZ} \
    --num_core_per_host=${TEST_NUM_CORE} \
    --do_train=False \
    --do_eval=True \
    --eval_split=test
TensorFlow/NLP/transformer-xl-master/test.sh (new file, mode 100644)
#!/bin/bash
./scripts/enwik8_base_test.sh train
#mpirun --allow-run-as-root -np 3 scripts/enwik8_base_gpu.sh train
TensorFlow/NLP/transformer-xl-master/tpu_estimator.py (new file, mode 100644)
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ===================================================================
"""TPUEstimator class."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import collections
import copy
import os
import signal
import sys
import threading
import time

import numpy as np
import six
from six.moves import queue as Queue  # pylint: disable=redefined-builtin
from six.moves import xrange  # pylint: disable=redefined-builtin
import math

try:
  import google3
  from google3.third_party.tensorflow.contrib.tpu.python.ops import tpu_ops
  from google3.third_party.tensorflow.contrib.tpu.python.tpu import error_handling
  from google3.third_party.tensorflow.contrib.tpu.python.tpu import session_support
  from google3.third_party.tensorflow.contrib.tpu.python.tpu import tpu
  from google3.third_party.tensorflow.contrib.tpu.python.tpu import tpu_config
  from google3.third_party.tensorflow.contrib.tpu.python.tpu import tpu_context
  from google3.third_party.tensorflow.contrib.tpu.python.tpu import tpu_feed
  from google3.third_party.tensorflow.contrib.tpu.python.tpu import training_loop
  from google3.third_party.tensorflow.contrib.tpu.python.tpu import util as util_lib
  from google3.third_party.tensorflow.contrib.training.python.training import hparam
  from google3.third_party.tensorflow.core.framework import variable_pb2
  from google3.third_party.tensorflow.core.framework.summary_pb2 import Summary
  from google3.third_party.tensorflow.core.protobuf import config_pb2
  from google3.third_party.tensorflow.python.data.ops import dataset_ops
  from google3.third_party.tensorflow.python.data.util import nest as data_nest
  from google3.third_party.tensorflow.python.estimator import estimator as estimator_lib
  from google3.third_party.tensorflow.python.estimator import model_fn as model_fn_lib
  from google3.third_party.tensorflow.python.estimator.export import export_output as export_output_lib
  from google3.third_party.tensorflow.python.framework import constant_op
  from google3.third_party.tensorflow.python.framework import dtypes
  from google3.third_party.tensorflow.python.framework import errors
  from google3.third_party.tensorflow.python.framework import ops
  from google3.third_party.tensorflow.python.ops import array_ops
  from google3.third_party.tensorflow.python.ops import check_ops
  from google3.third_party.tensorflow.python.ops import control_flow_ops
  from google3.third_party.tensorflow.python.ops import init_ops
  from google3.third_party.tensorflow.python.ops import math_ops
  from google3.third_party.tensorflow.python.ops import resource_variable_ops
  from google3.third_party.tensorflow.python.ops import state_ops
  from google3.third_party.tensorflow.python.ops import summary_ops_v2 as contrib_summary
  from google3.third_party.tensorflow.python.ops import variable_scope
  from google3.third_party.tensorflow.python.ops import variables
  from google3.third_party.tensorflow.python.platform import tf_logging as logging
  from google3.third_party.tensorflow.python.saved_model import tag_constants
  from google3.third_party.tensorflow.python.summary import summary
  from google3.third_party.tensorflow.python.training import basic_session_run_hooks
  from google3.third_party.tensorflow.python.training import evaluation
  from google3.third_party.tensorflow.python.training import session_run_hook
  from google3.third_party.tensorflow.python.training import training
  from google3.third_party.tensorflow.python.training import training_util
  from google3.third_party.tensorflow.python.util import function_utils
  from google3.third_party.tensorflow.python.util import nest
  from google3.third_party.tensorflow.python.util import tf_inspect
except:
  import tensorflow
  from tensorflow.contrib.tpu.python.ops import tpu_ops
  from tensorflow.contrib.tpu.python.tpu import error_handling
  from tensorflow.contrib.tpu.python.tpu import session_support
  from tensorflow.contrib.tpu.python.tpu import tpu
  from tensorflow.contrib.tpu.python.tpu import tpu_config
  from tensorflow.contrib.tpu.python.tpu import tpu_context
  from tensorflow.contrib.tpu.python.tpu import tpu_feed
  from tensorflow.contrib.tpu.python.tpu import training_loop
  from tensorflow.contrib.tpu.python.tpu import util as util_lib
  from tensorflow.contrib.training.python.training import hparam
  from tensorflow.core.framework import variable_pb2
  from tensorflow.core.framework.summary_pb2 import Summary
  from tensorflow.core.protobuf import config_pb2
  from tensorflow.python.data.ops import dataset_ops
  from tensorflow.python.data.util import nest as data_nest
  from tensorflow.python.estimator import estimator as estimator_lib
  from tensorflow.python.estimator import model_fn as model_fn_lib
  from tensorflow.python.estimator import util as estimator_util
  from tensorflow.python.estimator.export import export_output as export_output_lib
  from tensorflow.python.framework import constant_op
  from tensorflow.python.framework import dtypes
  from tensorflow.python.framework import errors
  from tensorflow.python.framework import ops
  from tensorflow.python.ops import array_ops
  from tensorflow.python.ops import check_ops
  from tensorflow.python.ops import control_flow_ops
  from tensorflow.python.ops import init_ops
  from tensorflow.python.ops import math_ops
  from tensorflow.python.ops import resource_variable_ops
  from tensorflow.python.ops import state_ops
  from tensorflow.python.ops import summary_ops_v2 as contrib_summary
  from tensorflow.python.ops import variable_scope
  from tensorflow.python.ops import variables
  from tensorflow.python.platform import tf_logging as logging
  from tensorflow.python.saved_model import tag_constants
  from tensorflow.python.summary import summary
  from tensorflow.python.training import basic_session_run_hooks
  from tensorflow.python.training import evaluation
  from tensorflow.python.training import session_run_hook
  from tensorflow.python.training import training
  from tensorflow.python.training import training_util
  from tensorflow.python.util import function_utils
  from tensorflow.python.util import nest
  from tensorflow.python.util import tf_inspect

_INITIAL_LOSS = 1e7
_ZERO_LOSS = 0.
_TPU_ESTIMATOR = 'custom_tpu_estimator'  # CHANGE FOR RECURRENCY
_ITERATIONS_PER_LOOP_VAR = 'iterations_per_loop'
_BATCH_SIZE_KEY = 'batch_size'
_CTX_KEY = 'context'
_USE_TPU_KEY = 'use_tpu'
_CROSS_REPLICA_SUM_OP = 'CrossReplicaSum'
_ONE_GIGABYTE = 1024 * 1024 * 1024
_TPU_ENQUEUE_OPS = '_tpu_enqueue_ops'
_TPU_TRAIN_OP = '_tpu_train_op'
_REWRITE_FOR_INFERENCE_MODE = '_rewrite_for_inference'

# Ideally _USE_TPU_KEY should be reserved as well. However there are already
# models that make use of this key, thus it can not be reserved now to prevent
# breakage. In the long run, we would like to mitigate this by migrating models
# off of using _USE_TPU_KEY.
_RESERVED_PARAMS_KEYS = [_BATCH_SIZE_KEY, _CTX_KEY]

# TODO(b/65703635): Flip the value and remove all dead code. Currently, this is
# only used for per-core based deployments. For per-host based pipelines, if a
# user returns a Dataset instance it will be automatically wrapped in a
# tf.while_loop (This can be disabled by returning features and labels
# explicitly).
_WRAP_INPUT_FN_INTO_WHILE_LOOP = False

ops.register_proto_function(
    '{}_{}'.format(_TPU_ESTIMATOR, _ITERATIONS_PER_LOOP_VAR),
    proto_type=variable_pb2.VariableDef,
    to_proto=resource_variable_ops._to_proto_fn,  # pylint: disable=protected-access
    from_proto=resource_variable_ops._from_proto_fn)  # pylint: disable=protected-access

def _create_global_step(graph):
  graph = graph or ops.get_default_graph()
  if training.get_global_step(graph) is not None:
    raise ValueError('"global_step" already exists.')
  # Create in proper graph and base name_scope.
  with graph.as_default() as g, g.name_scope(None):
    return variable_scope.get_variable(
        ops.GraphKeys.GLOBAL_STEP,
        shape=[],
        dtype=dtypes.int64,
        initializer=init_ops.zeros_initializer(),
        trainable=False,
        use_resource=True,
        collections=[ops.GraphKeys.GLOBAL_VARIABLES, ops.GraphKeys.GLOBAL_STEP])

def _create_or_get_iterations_per_loop():
  """Creates or gets the iterations_per_loop variable.

  In TPUEstimator, the user provided computation, the model_fn, is wrapped
  inside a tf.while_loop for peak performance. The iterations of the loop are
  specified by this variable, which adjusts its value on the CPU after each TPU
  program execution and before the next TPU execution.

  The purpose of using a variable, rather then a constant, is to allow
  TPUEstimator adapt the TPU training iterations according to the final steps
  specified by users. For example, if the user sets the iterations_per_loop as 4
  in TPUConfig and steps as 10 in TPUEstimator.train(), the iterations_per_loop
  variable will have the following value before each TPU training.

      - 1-th TPU execution: iterations_per_loop = 4
      - 2-th TPU execution: iterations_per_loop = 4
      - 3-th TPU execution: iterations_per_loop = 2

  As model_fn increases the global step once per train_op invocation, the global
  step is 10 after all TPU executions, matching the steps=10 inputs passed in by
  users.

  Returns:
    A TF non-trainable resource variable.

  Raises:
    RuntimeError: If multi iterations_per_loop variables were found.
  """
  graph = ops.get_default_graph()
  collection_name = '{}_{}'.format(_TPU_ESTIMATOR, _ITERATIONS_PER_LOOP_VAR)
  iter_vars = graph.get_collection(collection_name)
  if len(iter_vars) == 1:
    return iter_vars[0]
  elif len(iter_vars) > 1:
    raise RuntimeError('Multiple iterations_per_loop_var in collection.')

  with ops.colocate_with(training_util.get_global_step()):
    with variable_scope.variable_scope(
        _TPU_ESTIMATOR, reuse=variable_scope.AUTO_REUSE):
      return variable_scope.get_variable(
          _ITERATIONS_PER_LOOP_VAR,
          initializer=init_ops.zeros_initializer(),
          shape=[],
          dtype=dtypes.int32,
          trainable=False,
          collections=[collection_name, ops.GraphKeys.LOCAL_VARIABLES],
          use_resource=True)

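The schedule described in the docstring above is easy to check in isolation. A minimal editor's sketch (not part of the committed file) of the per-execution values the host loads into the variable, assuming iterations_per_loop=4 and steps=10:

# Editor's sketch: the schedule implied by _create_or_get_iterations_per_loop.
def iterations_schedule(total_steps, iterations_per_loop):
    """Yield the value loaded into iterations_per_loop before each TPU run."""
    remaining = total_steps
    while remaining > 0:
        step = min(iterations_per_loop, remaining)
        yield step
        remaining -= step

print(list(iterations_schedule(10, 4)))  # [4, 4, 2]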
def _sync_variables_ops():
  # Gets the variables back from TPU nodes. This means the variables updated
  # by TPU will now be *synced* to host memory.
  return [
      array_ops.check_numerics(v.read_value(),
                               'Gradient for %s is NaN' % v.name).op
      for v in variables.trainable_variables()
  ]


def _increase_eval_step_op(iterations_per_loop):
  """Returns an op to increase the eval step for TPU evaluation.

  Args:
    iterations_per_loop: Tensor. The number of eval steps running in TPU
      system before returning to CPU host for each `Session.run`.

  Returns:
    An operation
  """
  eval_step = evaluation._get_or_create_eval_step()  # pylint: disable=protected-access
  # Estimator evaluate increases 1 by default. So, we increase the difference.
  return state_ops.assign_add(
      eval_step,
      math_ops.cast(iterations_per_loop - 1, dtype=eval_step.dtype),
      use_locking=True)


def _extract_key_names(tensor_or_dict):
  if isinstance(tensor_or_dict, dict):
    return sorted(tensor_or_dict.keys())
  return []

class _SIGNAL(object):
  """Signal used to control the thread of infeed/outfeed.

  All preserved signals must be negative numbers. Positive numbers are used to
  indicate the number of iterations for next training/evaluation loop.
  """
  NEXT_BATCH = -1
  STOP = -2

class TPUEstimatorSpec(model_fn_lib._TPUEstimatorSpec):  # pylint: disable=protected-access
  """Ops and objects returned from a `model_fn` and passed to `TPUEstimator`.

  See `EstimatorSpec` for `mode`, `predictions`, `loss`, `train_op`, and
  `export_outputs`.

  For evaluation, `eval_metrics `is a tuple of `metric_fn` and `tensors`, where
  `metric_fn` runs on CPU to generate metrics and `tensors` represents the
  `Tensor`s transferred from TPU system to CPU host and passed to `metric_fn`.

  To be precise, TPU evaluation expects a slightly different signature from the
  @{tf.estimator.Estimator}. While `EstimatorSpec.eval_metric_ops` expects a
  dict, `TPUEstimatorSpec.eval_metrics` is a tuple of `metric_fn` and `tensors`.
  The `tensors` could be a list of `Tensor`s or dict of names to `Tensor`s. The
  `tensors` usually specify the model logits, which are transferred back from
  TPU system to CPU host. All tensors must have be batch-major, i.e., the batch
  size is the first dimension. Once all tensors are available at CPU host from
  all shards, they are concatenated (on CPU) and passed as positional arguments
  to the `metric_fn` if `tensors` is list or keyword arguments if `tensors` is
  a dict. `metric_fn` takes the `tensors` and returns a dict from metric string
  name to the result of calling a metric function, namely a `(metric_tensor,
  update_op)` tuple. See `TPUEstimator` for MNIST example how to specify the
  `eval_metrics`.

  `scaffold_fn` is a function running on CPU to generate the `Scaffold`. This
  function should not capture any Tensors in `model_fn`.

  `host_call` is a tuple of a `function` and a list or dictionary of `tensors`
  to pass to that function and returns a list of Tensors. `host_call` currently
  works for train() and evaluate(). The Tensors returned by the function is
  executed on the CPU on every step, so there is communication overhead when
  sending tensors from TPU to CPU. To reduce the overhead, try reducing the
  size of the tensors. The `tensors` are concatenated along their major (batch)
  dimension, and so must be >= rank 1. The `host_call` is useful for writing
  summaries with @{tf.contrib.summary.create_file_writer}.
  """

  def __new__(cls,
              mode,
              predictions=None,
              loss=None,
              train_op=None,
              eval_metrics=None,
              export_outputs=None,
              scaffold_fn=None,
              host_call=None,
              training_hooks=None,
              evaluation_hooks=None,
              prediction_hooks=None):
    """Creates a validated `TPUEstimatorSpec` instance."""
    host_calls = {}
    if eval_metrics is not None:
      host_calls['eval_metrics'] = eval_metrics
    if host_call is not None:
      host_calls['host_call'] = host_call
    _OutfeedHostCall.validate(host_calls)

    training_hooks = list(training_hooks or [])
    evaluation_hooks = list(evaluation_hooks or [])
    prediction_hooks = list(prediction_hooks or [])

    for hook in training_hooks + evaluation_hooks + prediction_hooks:
      if not isinstance(hook, session_run_hook.SessionRunHook):
        raise TypeError(
            'All hooks must be SessionRunHook instances, given: {}'.format(hook))

    return super(TPUEstimatorSpec, cls).__new__(
        cls,
        mode=mode,
        predictions=predictions,
        loss=loss,
        train_op=train_op,
        eval_metrics=eval_metrics,
        export_outputs=export_outputs,
        scaffold_fn=scaffold_fn,
        host_call=host_call,
        training_hooks=training_hooks,
        evaluation_hooks=evaluation_hooks,
        prediction_hooks=prediction_hooks)

  def as_estimator_spec(self):
    """Creates an equivalent `EstimatorSpec` used by CPU train/eval."""
    host_calls = {}
    if self.eval_metrics is not None:
      host_calls['eval_metrics'] = self.eval_metrics
    if self.host_call is not None:
      host_calls['host_call'] = self.host_call
    host_call_ret = _OutfeedHostCall.create_cpu_hostcall(host_calls)
    eval_metric_ops = None
    if self.eval_metrics is not None:
      eval_metric_ops = host_call_ret['eval_metrics']
    hooks = None
    if self.host_call is not None:
      hooks = [_OutfeedHostCallHook(host_call_ret['host_call'])]
    hooks = list(hooks or [])
    scaffold = self.scaffold_fn() if self.scaffold_fn else None
    return model_fn_lib.EstimatorSpec(
        mode=self.mode,
        predictions=self.predictions,
        loss=self.loss,
        train_op=self.train_op,
        eval_metric_ops=eval_metric_ops,
        export_outputs=self.export_outputs,
        scaffold=scaffold,
        training_hooks=self.training_hooks + hooks,
        evaluation_hooks=self.evaluation_hooks + hooks,
        prediction_hooks=self.prediction_hooks + hooks)

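As a usage illustration of the `eval_metrics` contract described in the class docstring, a model_fn would typically return the spec roughly as in the editor's sketch below. This is a hedged example, not taken from this repository: `example_model_fn`, `my_model`, and `compute_loss` are placeholders, and a real TPU setup would normally also wrap the optimizer for cross-shard gradient aggregation.

# Editor's sketch (hypothetical model_fn, not from this commit): a model_fn
# returning a TPUEstimatorSpec whose eval_metrics is a (metric_fn, tensors)
# tuple. metric_fn runs on the CPU host over batch-major tensors from TPU.
import tensorflow as tf

def my_model(features, params):        # placeholder model
    return tf.layers.dense(features['inputs'], params['n_classes'])

def compute_loss(labels, logits):      # placeholder loss
    return tf.losses.sparse_softmax_cross_entropy(labels=labels, logits=logits)

def example_model_fn(features, labels, mode, params):
    logits = my_model(features, params)
    loss = compute_loss(labels, logits)

    def metric_fn(labels, logits):
        predictions = tf.argmax(logits, axis=-1)
        return {'accuracy': tf.metrics.accuracy(labels, predictions)}

    if mode == tf.estimator.ModeKeys.EVAL:
        return TPUEstimatorSpec(
            mode=mode, loss=loss, eval_metrics=(metric_fn, [labels, logits]))

    train_op = tf.train.AdamOptimizer().minimize(
        loss, global_step=tf.train.get_global_step())
    return TPUEstimatorSpec(mode=mode, loss=loss, train_op=train_op)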
class _OpQueueContext(object):
  """Manages work queue and thread for a infeed/outfeed thread."""

  def __init__(self, name, target, args):
    self._name = name
    self._queue = Queue.Queue()
    args = (self,) + args
    self._thread = threading.Thread(name=name, target=target, args=args)
    self._thread.daemon = True
    self._thread.start()

  def stop(self):
    self._queue.put(_SIGNAL.STOP)

  def send_next_batch_signal(self, iterations):
    self._queue.put(iterations)

  def read_iteration_counts(self):
    while True:
      iterations = self._queue.get(block=True)
      logging.debug('%s read iterations %s', self._name, iterations)
      if iterations == _SIGNAL.STOP:
        logging.info('%s received shutdown signal, stopping.', self._name)
        return
      yield iterations

  def join(self):
    logging.info('Shutting down %s thread.' % self._name)
    self.stop()
    self._thread.join()

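The handshake above is plain producer/consumer threading: the coordinator posts iteration counts onto a queue, and the worker thread consumes them until it sees the STOP sentinel. A standalone editor's sketch of the same protocol (using only the standard library plus six, mirroring the imports in this file):

# Editor's sketch: the signalling protocol used by _OpQueueContext. The main
# thread posts iteration counts; the worker drains them until it sees the
# STOP sentinel (-2), matching _SIGNAL.STOP.
import threading
from six.moves import queue as Queue

STOP = -2

def worker(q):
    while True:
        iterations = q.get(block=True)
        if iterations == STOP:
            print('worker: shutdown signal received')
            return
        print('worker: run %d iterations' % iterations)

q = Queue.Queue()
t = threading.Thread(target=worker, args=(q,))
t.daemon = True
t.start()
for n in (4, 4, 2):
    q.put(n)        # send_next_batch_signal
q.put(STOP)         # stop()
t.join()            # join()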
class _OpSignalOnceQueueContext(_OpQueueContext):
  """Manages work queue and thread for a infeed/outfeed thread.

  This subclass only signals once.
  """

  def __init__(self, name, target, args):
    super(_OpSignalOnceQueueContext, self).__init__(name, target, args)
    self._has_signaled = False

  def send_next_batch_signal(self, iterations):
    if not self._has_signaled:
      self._queue.put(iterations)
      self._has_signaled = True

class TPUInfeedOutfeedSessionHook(session_run_hook.SessionRunHook):
  """A Session hook setting up the TPU initialization, infeed, and outfeed.

  This hook does two major things:
  1. initialize and shutdown TPU system.
  2. launch and join the threads for infeed enqueue and (optional) outfeed
     dequeue.
  """

  def __init__(self,
               ctx,
               enqueue_ops,
               dequeue_ops,
               run_infeed_loop_on_coordinator=True,
               rendezvous=None):
    self._master_job = ctx.master_job
    self._enqueue_ops = enqueue_ops
    self._dequeue_ops = dequeue_ops
    self._rendezvous = rendezvous

    self._run_infeed_loop_on_coordinator = run_infeed_loop_on_coordinator
    self._initial_infeed_sleep_secs = (
        ctx.config.tpu_config.initial_infeed_sleep_secs)

    self._feed_error = None
    self._finished = False

  def begin(self):
    logging.info('TPU job name %s', self._master_job)
    self._iterations_per_loop_var = _create_or_get_iterations_per_loop()
    self._init_ops = [tpu.initialize_system(job=self._master_job)]
    self._finalize_ops = [tpu.shutdown_system(job=self._master_job)]

    summary_writer_init_ops = contrib_summary.summary_writer_initializer_op()
    self._init_ops.extend(summary_writer_init_ops)
    # Get all the writer resources from the initializer, so we know what to
    # flush.
    for op in summary_writer_init_ops:
      self._finalize_ops.append(contrib_summary.flush(writer=op.inputs[0]))

  def _run_infeed(self, queue_ctx, session):
    logging.info('Starting infeed thread controller.')
    if self._initial_infeed_sleep_secs:
      logging.info('%s thread sleeping for %d seconds.', self._name,
                   self._initial_infeed_sleep_secs)
      time.sleep(self._initial_infeed_sleep_secs)
      logging.info('%s thread starting after sleep', self._name)

    with self._rendezvous.catch_errors(source='infeed', session=session):
      if self._run_infeed_loop_on_coordinator:
        for count, steps in enumerate(queue_ctx.read_iteration_counts()):
          for i in xrange(steps):
            logging.debug('Infeed enqueue for iteration (%d, %d)', count, i)
            session.run(self._enqueue_ops)
      else:
        for _ in queue_ctx.read_iteration_counts():
          session.run(self._enqueue_ops)
      logging.info('Infeed thread finished, shutting down.')

  def _run_outfeed(self, queue_ctx, session):
    logging.info('Starting outfeed thread controller.')
    with self._rendezvous.catch_errors(source='outfeed', session=session):
      for count, steps in enumerate(queue_ctx.read_iteration_counts()):
        for i in xrange(steps):
          logging.debug('Outfeed dequeue for iteration (%d, %d)', count, i)
          session.run(self._dequeue_ops)
      logging.info('Outfeed thread finished, shutting down.')

  def _create_infeed_controller(self, name, target, args):
    return _OpQueueContext(name=name, target=target, args=args)

  def after_create_session(self, session, coord):
    logging.info('Init TPU system')
    session.run(self._init_ops,
                options=config_pb2.RunOptions(timeout_in_ms=5 * 60 * 1000))

    self._infeed_controller = self._create_infeed_controller(
        name='InfeedController', target=self._run_infeed, args=(session,))

    self._outfeed_controller = _OpQueueContext(
        name='OutfeedController', target=self._run_outfeed, args=(session,))

  def before_run(self, run_context):
    self._feed_error = None

    iterations = run_context.session.run(self._iterations_per_loop_var)

    logging.info('Enqueue next (%d) batch(es) of data to infeed.', iterations)
    self._infeed_controller.send_next_batch_signal(iterations)

    logging.info('Dequeue next (%d) batch(es) of data from outfeed.',
                 iterations)
    self._outfeed_controller.send_next_batch_signal(iterations)

  def end(self, session):
    self._finished = True
    logging.info('Stop infeed thread controller')
    self._infeed_controller.join()
    self._rendezvous.record_done('infeed')

    logging.info('Stop output thread controller')
    self._outfeed_controller.join()
    self._rendezvous.record_done('outfeed')

    logging.info('Shutdown TPU system.')
    session.run(self._finalize_ops)

class TPUInfeedOutfeedSessionHookForPrediction(TPUInfeedOutfeedSessionHook):

  def __init__(self, ctx, enqueue_ops, dequeue_ops, rendezvous=None):
    super(TPUInfeedOutfeedSessionHookForPrediction, self).__init__(
        ctx,
        enqueue_ops,
        dequeue_ops,
        run_infeed_loop_on_coordinator=False,
        rendezvous=rendezvous)

  def _create_infeed_controller(self, name, target, args):
    return _OpSignalOnceQueueContext(name=name, target=target, args=args)

class _TPUStopAtStepHook(session_run_hook.SessionRunHook):
  """Hook that requests stop at a specified step.

  This hook is similar to the `session_run_hook._StopAfterNEvalsHook` with
  following differences for TPU training:

  1. This hook sets the variable for iterations_per_loop, which is used by
     `TPUInfeedOutfeedSessionHook` to control the iterations for infeed/outfeed.
     As the hook execution order is not guaranteed, the variable update is
     handled in `after_create_session` and `after_run` as
     `TPUInfeedOutfeedSessionHook` reads the variable value in `before_run`.

  2. For each training loop (session.run), the global step could be increased
     multiple times on TPU. The global step tensor value will be explicitly read
     again in `after_run` to ensure the latest value is retrieved to avoid race
     condition.
  """

  def __init__(self, iterations, num_steps=None, last_step=None):
    """Initializes a `StopAtStepHook`.

    Args:
      iterations: The number of iterations to run optimizer per training loop.
      num_steps: Number of steps to execute.
      last_step: Step after which to stop.

    Raises:
      ValueError: If one of the arguments is invalid.
    """
    if num_steps is None and last_step is None:
      raise ValueError('One of num_steps or last_step must be specified.')
    if num_steps is not None and last_step is not None:
      raise ValueError('Only one of num_steps or last_step can be specified.')
    self._num_steps = num_steps
    self._last_step = last_step
    self._iterations = iterations

  def _next_iterations(self, global_step, last_step):
    gap = last_step - global_step
    return min(gap, self._iterations)

  def begin(self):
    self._global_step_tensor = training_util.get_global_step()
    if self._global_step_tensor is None:
      raise RuntimeError('Global step should be created.')

    self._iterations_per_loop_var = _create_or_get_iterations_per_loop()

  def after_create_session(self, session, coord):
    global_step = session.run(self._global_step_tensor)
    if self._last_step is None:
      self._last_step = global_step + self._num_steps

    iterations = self._next_iterations(global_step, self._last_step)

    self._iterations_per_loop_var.load(iterations, session=session)

  def after_run(self, run_context, run_values):
    # Global step cannot be retrieved via SessionRunArgs and before_run due to
    # race condition.
    global_step = run_context.session.run(self._global_step_tensor)
    if global_step >= self._last_step:
      run_context.request_stop()
    else:
      iterations = self._next_iterations(global_step, self._last_step)
      self._iterations_per_loop_var.load(
          iterations, session=run_context.session)

class _SetEvalIterationsHook(session_run_hook.SessionRunHook):
  """Hook that requests stop at a specified step."""

  def __init__(self, num_steps):
    """Initializes a `_SetEvalIterationsHook`.

    Args:
      num_steps: Number of steps to execute.
    """
    self._num_steps = num_steps

  def begin(self):
    self._iterations_per_loop_var = _create_or_get_iterations_per_loop()

  def after_create_session(self, session, coord):
    self._iterations_per_loop_var.load(self._num_steps, session=session)

class _StoppingPredictHook(session_run_hook.SessionRunHook):
  """Hook that requests stop according to the stopping signal in prediction."""

  def __init__(self, scalar_stopping_signal):
    self._scalar_stopping_signal = scalar_stopping_signal

  def begin(self):
    self._iterations_per_loop_var = _create_or_get_iterations_per_loop()

  def after_create_session(self, session, coord):
    # This is not necessary as we do not run infeed enqueue and outfeed dequeue
    # in side threads for prediction model. But it makes the
    # TPUInfeedOutfeedSessionHook prints nice message.
    self._iterations_per_loop_var.load(1, session=session)

  def before_run(self, run_context):
    return session_run_hook.SessionRunArgs(self._scalar_stopping_signal)

  def after_run(self, run_context, run_values):
    _ = run_context
    scalar_stopping_signal = run_values.results
    if _StopSignals.should_stop(scalar_stopping_signal):
      # NOTE(xiejw): In prediction, stopping signals are inserted for each
      # batch. And we append one more batch to signal the system it should stop.
      # The data flow might look like
      #
      #  batch   0: images, labels, stop = 0  (user provided)
      #  batch   1: images, labels, stop = 0  (user provided)
      #  ...
      #  batch  99: images, labels, stop = 0  (user provided)
      #  batch 100: images, labels, stop = 1  (TPUEstimator appended)
      #
      # where the final batch (id = 100) is appended by TPUEstimator, so we
      # should drop it before returning the predictions to user.
      # To achieve that, we throw the OutOfRangeError in after_run. Once
      # Monitored Session sees this error in SessionRunHook.after_run, the
      # "current" prediction, i.e., batch with id=100, will be discarded
      # immediately
      raise errors.OutOfRangeError(None, None, 'Stopped by stopping signal.')

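The stopping-signal flow described in the NOTE above — user batches carry stop = 0 and TPUEstimator appends one final batch with stop = 1, which is then dropped — can be pictured with a small editor's sketch (pure Python, not from the library):

# Editor's sketch: the prediction stream as described in the NOTE above.
# Each element is (prediction, stop_signal); the appended stop == 1 batch only
# exists to end the loop and never reaches the user.
user_batches = [('pred_%d' % i, 0) for i in range(3)]   # stop = 0, user data
stream = user_batches + [('padding', 1)]                # stop = 1, appended

results = []
for prediction, stop in stream:
    if stop:              # _StopSignals.should_stop(...)
        break             # the padded batch is discarded
    results.append(prediction)

print(results)  # ['pred_0', 'pred_1', 'pred_2']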
def generate_per_core_enqueue_ops_fn_for_host(
    ctx, input_fn, inputs_structure_recorder, host_device, host_id):
  """Generates infeed enqueue ops for per-core input_fn on a single host."""
  captured_infeed_queue = _CapturedObject()

  tpu_ordinal_function_impl = ctx.tpu_ordinal_function(host_id)

  def enqueue_ops_fn():
    """A fn returns enqueue_ops."""
    num_cores_per_host = ctx.num_of_cores_per_host
    per_host_sharded_inputs = []
    for core_ordinal in range(num_cores_per_host):
      with ops.name_scope('ordinal_%d' % (core_ordinal)):
        user_context = tpu_context.TPUContext(
            internal_ctx=ctx,
            input_device=host_device,
            invocation_index=host_id * ctx.num_of_cores_per_host + core_ordinal)
        inputs = _Inputs.from_input_fn(input_fn(user_context))
        if inputs.is_dataset:
          raise TypeError(
              '`input_fn` returning `Dataset` is not yet supported in '
              'per-Core input pipeline deployment yet. Please set '
              'TPUConfig.per_host_input_for_training to True or return '
              '`features` and `labels` from `input_fn`')
        features, labels = inputs.features_and_labels()

        inputs_structure_recorder.validate_and_record_structure(
            features, labels)
        flattened_inputs = (
            inputs_structure_recorder.flatten_features_and_labels(
                features, labels))
        per_host_sharded_inputs.append(flattened_inputs)

    infeed_queue = tpu_feed.InfeedQueue(
        number_of_tuple_elements=len(per_host_sharded_inputs[0]))
    captured_infeed_queue.capture(infeed_queue)

    per_host_enqueue_ops = infeed_queue.generate_enqueue_ops(
        per_host_sharded_inputs,
        tpu_ordinal_function=tpu_ordinal_function_impl)
    return per_host_enqueue_ops

  return enqueue_ops_fn, captured_infeed_queue

def generate_per_host_enqueue_ops_fn_for_host(
    ctx, input_fn, inputs_structure_recorder, batch_axis, device, host_id):
  """Generates infeed enqueue ops for per-host input_fn on a single host."""
  captured_infeed_queue = _CapturedObject()

  hooks = []

  with ops.device(device):
    user_context = tpu_context.TPUContext(
        internal_ctx=ctx, input_device=device, invocation_index=host_id)
    inputs = _Inputs.from_input_fn(input_fn(user_context))

    is_dataset = inputs.is_dataset
    if ctx.mode == model_fn_lib.ModeKeys.PREDICT:
      if not is_dataset:
        raise TypeError(
            'For mode PREDICT, `input_fn` must return `Dataset` instead of '
            '`features` and `labels`.')
      if batch_axis is not None:
        raise TypeError('For mode PREDICT, batch_axis is not supported yet.')
      inputs = _InputsWithStoppingSignals(
          dataset=inputs.dataset,
          batch_size=ctx.batch_size_for_input_fn,
          add_padding=True)

    if is_dataset:
      hooks.append(inputs.dataset_initializer_hook())

    tpu_ordinal_function_impl = ctx.tpu_ordinal_function(host_id)

  def enqueue_ops_fn():
    """A Fn returning the TPU infeed enqueue ops.

    By providing as a Fn, it can be invoked inside the tf.while_loop such that
    the input pipeline for multiple iterations can be executed by one
    Session.run call.

    Returns:
      list of dict of ops.
    """
    with ops.device(device):
      num_of_replicas_per_host = ctx.num_of_replicas_per_host
      # Convert user input to features and labels. If the user returns a
      # dataset, it is initialized and the features and labels extracted via
      # `dataset.iterator.get_next()`
      features, labels = inputs.features_and_labels()
      signals = inputs.signals()

      inputs_structure_recorder.validate_and_record_structure(features, labels)
      unsharded_tensor_list = (
          inputs_structure_recorder.flatten_features_and_labels(
              features, labels, signals))

      infeed_queue = tpu_feed.InfeedQueue(
          tuple_types=[t.dtype for t in unsharded_tensor_list],
          tuple_shapes=[t.shape for t in unsharded_tensor_list],
          shard_dimensions=batch_axis)
      captured_infeed_queue.capture(infeed_queue)
      infeed_queue.set_number_of_shards(num_of_replicas_per_host)
      per_host_enqueue_ops = (
          infeed_queue.split_inputs_and_generate_enqueue_ops(
              unsharded_tensor_list,
              placement_function=lambda x: device,
              tpu_ordinal_function=tpu_ordinal_function_impl))
      if signals is None:
        return per_host_enqueue_ops
      else:
        return {
            'ops': per_host_enqueue_ops,
            'signals': signals,
        }

  return enqueue_ops_fn, captured_infeed_queue, hooks, is_dataset

def generate_per_host_v2_enqueue_ops_fn_for_host(
    ctx, input_fn, inputs_structure_recorder, device, host_id):
  """Generates infeed enqueue ops for per-host input_fn on a single host."""
  captured_infeed_queue = _CapturedObject()
  hooks = []

  with ops.device(device):
    user_context = tpu_context.TPUContext(
        internal_ctx=ctx, input_device=device, invocation_index=host_id)
    inputs = _Inputs.from_input_fn(input_fn(user_context))

    is_dataset = inputs.is_dataset
    if not is_dataset:
      raise TypeError('`input_fn` must return a `Dataset` for the PER_HOST_V2 '
                      'input pipeline configuration.')

    if ctx.mode == model_fn_lib.ModeKeys.PREDICT:
      inputs = _InputsWithStoppingSignals(
          dataset=inputs.dataset,
          batch_size=ctx.batch_size_for_input_fn,
          add_padding=True,
          num_invocations_per_step=ctx.num_of_replicas_per_host)

    hooks.append(inputs.dataset_initializer_hook())
    tpu_ordinal_function_impl = ctx.tpu_ordinal_function(host_id)

  def enqueue_ops_fn():
    """Generates the per_host enqueue ops."""
    control_deps = []
    per_host_sharded_inputs = []
    num_replicas_per_host = ctx.num_of_replicas_per_host
    cached_signals = None
    with ops.device(device):
      if not inputs.is_dataset:
        raise TypeError('`input_fn` must return a `Dataset` for this mode.')
      for _ in range(num_replicas_per_host):
        # Use control dependencies to ensure a deterministic ordering.
        with ops.control_dependencies(control_deps):
          features, labels = inputs.features_and_labels()  # Calls get_next()
          signals = inputs.signals()

          # All the replicas share the replica 0's stopping singal.
          # This avoids inconsistent state among different model replcias.
          if cached_signals:
            signals['stopping'] = cached_signals['stopping']
          else:
            cached_signals = signals

        inputs_structure_recorder.validate_and_record_structure(
            features, labels)
        flattened_inputs = (
            inputs_structure_recorder.flatten_features_and_labels(
                features, labels, signals))
        control_deps.extend(flattened_inputs)
        per_host_sharded_inputs.append(flattened_inputs)

      if inputs_structure_recorder.flattened_input_dims:
        input_partition_dims = inputs_structure_recorder.flattened_input_dims
        if signals:
          input_partition_dims += [None] * len(signals)
        # pylint: disable=protected-access
        infeed_queue = tpu_feed._PartitionedInfeedQueue(
            number_of_tuple_elements=len(per_host_sharded_inputs[0]),
            host_id=host_id,
            input_partition_dims=input_partition_dims,
            device_assignment=ctx.device_assignment)
        per_host_enqueue_ops = infeed_queue.generate_enqueue_ops(
            per_host_sharded_inputs)
      else:
        infeed_queue = tpu_feed.InfeedQueue(
            number_of_tuple_elements=len(per_host_sharded_inputs[0]))
        per_host_enqueue_ops = infeed_queue.generate_enqueue_ops(
            per_host_sharded_inputs,
            tpu_ordinal_function=tpu_ordinal_function_impl)
      captured_infeed_queue.capture(infeed_queue)

    if signals is None:
      return per_host_enqueue_ops
    else:
      return {
          'ops': per_host_enqueue_ops,
          'signals': signals,
      }

  return enqueue_ops_fn, captured_infeed_queue, hooks, is_dataset

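The `cached_signals` handling above makes every replica reuse replica 0's 'stopping' entry so that all replicas agree on when to stop. A minimal editor's sketch of that sharing (pure Python, values are stand-ins for the signal tensors):

# Editor's sketch: replicas reuse the first replica's 'stopping' value, as the
# cached_signals logic in enqueue_ops_fn above does.
def share_stopping_signal(per_replica_signals):
    cached = None
    shared = []
    for signals in per_replica_signals:
        signals = dict(signals)
        if cached:
            signals['stopping'] = cached['stopping']
        else:
            cached = signals
        shared.append(signals)
    return shared

replicas = [{'stopping': 0, 'padding_mask': i} for i in range(4)]
print([s['stopping'] for s in share_stopping_signal(replicas)])  # [0, 0, 0, 0]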
def generate_broadcast_enqueue_ops_fn(ctx, input_fn, inputs_structure_recorder,
                                      num_hosts):
  """Generates infeed enqueue ops for one input_fn on all the hosts."""
  captured_infeed_queue = _CapturedObject()
  hooks = []
  device_0 = ctx.tpu_host_placement_function(host_id=0)
  with ops.device(device_0):
    user_context = tpu_context.TPUContext(
        internal_ctx=ctx, input_device=device_0, invocation_index=0)
    inputs = _Inputs.from_input_fn(input_fn(user_context))

    is_dataset = inputs.is_dataset
    if ctx.mode == model_fn_lib.ModeKeys.PREDICT:
      if not is_dataset:
        raise TypeError(
            'For mode PREDICT, `input_fn` must return `Dataset` instead of '
            '`features` and `labels`.')

      inputs = _InputsWithStoppingSignals(
          dataset=inputs.dataset,
          batch_size=ctx.batch_size_for_input_fn,
          add_padding=True)

    if is_dataset:
      hooks.append(inputs.dataset_initializer_hook())
    num_replicas_per_host = ctx.num_of_replicas_per_host

  def tpu_ordinal_function_impl(replica_id):
    if ctx.device_assignment:
      return ctx.device_assignment.tpu_ordinal(replica=replica_id)
    else:
      return replica_id % num_replicas_per_host

  def device_function_impl(replica_id):
    return ctx.tpu_host_placement_function(replica_id=replica_id)

  def enqueue_ops_fn():
    """Generates enqueue ops for all the hosts."""
    broadcasted_inputs = []
    flattened_inputs = None  # Cache result from input_fn.
    signals = None
    for host_id in xrange(num_hosts):
      with ops.device(ctx.tpu_host_placement_function(host_id=host_id)):
        for _ in xrange(ctx.num_of_replicas_per_host):
          # Note: input_fn is only called once at host 0 for the first replica.
          # The features and labels returned from that invocation are
          # broadcasted to other replicas(including the replicas on other
          # hosts).
          if flattened_inputs is None:
            features, labels = inputs.features_and_labels()  # Calls get_next()
            signals = inputs.signals()

            inputs_structure_recorder.validate_and_record_structure(
                features, labels)
            flattened_inputs = (
                inputs_structure_recorder.flatten_features_and_labels(
                    features, labels, signals))
          broadcasted_inputs.append(flattened_inputs)

    infeed_queue = tpu_feed.InfeedQueue(
        number_of_tuple_elements=len(broadcasted_inputs[0]))
    captured_infeed_queue.capture(infeed_queue)
    enqueue_ops = infeed_queue.generate_enqueue_ops(
        broadcasted_inputs,
        tpu_ordinal_function=tpu_ordinal_function_impl,
        placement_function=device_function_impl)

    if signals is None:
      return enqueue_ops
    else:
      return {
          'ops': enqueue_ops,
          'signals': signals,
      }

  return enqueue_ops_fn, captured_infeed_queue, hooks, is_dataset

class _InputPipeline(object):
  """`_InputPipeline` handles invoking `input_fn` and piping to infeed queue.

  `_InputPipeline` abstracts the per-core/per-host `input_fn` invocation from
  call site. To be precise, based on the configuration in
  `_InternalTPUContext`, it invokes `input_fn` for all cores (usually
  multi-host TPU training) or for one host (usually for single-host TPU
  evaluation), and sends all `features` and `labels` returned by `input_fn` to
  TPU infeed. For per-core invocation, `features` and `labels` are piped to
  infeed directly, one tuple for each core. For per-host invocation, `features`
  and `labels` are split at host (with respect to `batch_axis`) and piped to all
  cores accordingly.

  In addition, flatten/unflatten are handled by `_InputPipeline` also. Model
  inputs returned by the `input_fn` can have one of the following forms:
  1. features
  2. (features, labels)
  3. ((arbitrarily nested structure of features), labels)

  Internally, form 1 is reformed to `(features, None)` as features and labels
  are passed separately to underlying methods. For TPU training, TPUEstimator
  may expect multiple `features` and `labels` tuples one for each core.

  TPUEstimator allows various different structures for inputs (namely `features`
  and `labels`). `features` can be `Tensor`, dict of string name to `Tensor`,
  or nested tuples and `labels` could be `None`, `Tensor`, or dict of string
  name to `Tensor`. TPU infeed/outfeed library expects flattened tensor list.
  So, `features` and `labels` need to be flattened, before infeed enqueue, and
  the structure of them needs to be recorded, in order to restore them after
  infeed dequeue.
  """

  class InputsStructureRecorder(object):
    """The recorder to record inputs structure."""

    def __init__(self, input_partition_dims=None):
      # Holds the structure of inputs
      self._feature_structure = {}
      self._flattened_input_dims = None

      if input_partition_dims:
        # This should have been validated in TPUConfig.
        assert len(input_partition_dims) <= 2, 'must have 1 or 2 elements.'
        if len(input_partition_dims) == 2:
          self._feature_dims, self._label_dims = input_partition_dims
        else:
          self._feature_dims = input_partition_dims[0]
          self._label_dims = None

        assert self._feature_dims is not None, ('input_partition_dims[0] must '
                                                'not be None')
      else:
        self._feature_dims = None
        self._label_dims = None

      # Internal state.
      self._initialized = False

    @property
    def flattened_input_dims(self):
      assert self._initialized, 'InputsStructureRecorder is not initialized.'
      return self._flattened_input_dims

    def has_labels(self):
      return 'labels' in self._feature_structure

    def _flatten_input_dims(self, feature_dims, feature_dims_names, label_dims,
                            label_dims_names, label_names, has_labels):
      """Flatten input dims with the same order as flattened input tensors."""
      flattened_input_dims = []
      if feature_dims_names:
        # We need a fixed ordering for matching the tensors in features.
        flattened_input_dims.extend(
            [feature_dims[name] for name in feature_dims_names])
      else:
        flattened_input_dims.append(feature_dims)

      if label_dims_names:
        # We need a fixed ordering for matching the tensors in labels.
        flattened_input_dims.extend(
            [label_dims[name] for name in label_dims_names])
      else:
        if label_names:
          num_tensors_in_label = len(label_names)
        else:
          num_tensors_in_label = int(has_labels)
        # Setting `None` in input_partition_dims[1] will apply `None` to
        # all the tensors in labels, regardless of internal structure.
        flattened_input_dims.extend([label_dims] * num_tensors_in_label)

      return flattened_input_dims

    def validate_and_record_structure(self, features, labels):
      """Validates and records the structure of `features` and `labels`."""
      # Extract structure.
      has_labels = labels is not None
      feature_names = _extract_key_names(features)
      label_names = _extract_key_names(labels)

      if not self._initialized:
        # Record structure.
        self._initialized = True
        if self._feature_dims is not None:
          feature_dims_names = _extract_key_names(self._feature_dims)
          if feature_dims_names != feature_names:
            raise ValueError(
                'TPUConfig.input_partition_dims[0] mismatched feature'
                ' keys. Expected {}, got {}'.format(feature_names,
                                                    feature_dims_names))

          label_dims_names = _extract_key_names(self._label_dims)
          if self._label_dims is not None and label_dims_names != label_names:
            raise ValueError(
                'TPUConfig.input_partition_dims[1] mismatched label'
                ' keys. Expected {}, got {}'.format(label_names,
                                                    label_dims_names))

          self._flattened_input_dims = self._flatten_input_dims(
              self._feature_dims, feature_dims_names, self._label_dims,
              label_dims_names, label_names, has_labels)

    def flatten_features_and_labels(self, features, labels, signals=None):
      """Flattens the `features` and `labels` to a single tensor list."""
      self._feature_structure['features'] = features
      if labels is not None:
        self._feature_structure['labels'] = labels
      if signals is not None:
        self._feature_structure['signals'] = signals
      return data_nest.flatten(self._feature_structure)

    def unflatten_features_and_labels(self, flattened_inputs):
      """Restores the flattened inputs to original features and labels form.

      Args:
        flattened_inputs: Flattened inputs for each shard.

      Returns:
        A tuple of (`features`, `labels`), where `labels` could be None.
        Each one, if present, should have identical structure (single tensor vs
        dict) as the one returned by input_fn.

      Raises:
        ValueError: If the number of expected tensors from `flattened_inputs`
          mismatches the recorded structure.
      """
      unflattened_inputs = data_nest.pack_sequence_as(self._feature_structure,
                                                      flattened_inputs)
      return _Inputs(
          unflattened_inputs['features'],
          unflattened_inputs.get('labels'),
          signals=unflattened_inputs.get('signals'))

  def __init__(self, input_fn, batch_axis, ctx):
    """Constructor.

    Args:
      input_fn: input fn for train or eval.
      batch_axis: A python tuple of int values describing how each tensor
        produced by the Estimator `input_fn` should be split across the TPU
        compute shards.
      ctx: A `_InternalTPUContext` instance with mode.

    Raises:
      ValueError: If both `sharded_features` and `num_cores` are `None`.
    """
    self._inputs_structure_recorder = _InputPipeline.InputsStructureRecorder(
        ctx.input_partition_dims)

    self._sharded_per_core = ctx.is_input_sharded_per_core()
    self._input_fn = input_fn
    self._infeed_queue = None
    self._ctx = ctx
    self._batch_axis = batch_axis

  def generate_infeed_enqueue_ops_and_dequeue_fn(self):
    """Generates infeed enqueue ops and dequeue_fn."""
    # While tf.while_loop is called, the body function, which invokes
    # `enqueue_fn` passed in, is called to construct the graph. So, input_fn
    # structure is recorded.
    enqueue_ops, all_hooks, run_infeed_loop_on_coordinator = (
        self._invoke_input_fn_and_record_structure())

    self._validate_input_pipeline()

    def dequeue_fn():
      """dequeue_fn is used by TPU to retrieve the tensors."""
      # In the model-parallel case, both the host-side and device-side
      # computations must agree on the core on which infeed takes place. We
      # choose to perform infeed on logical core 0 of each replica.
      values = self._infeed_queue.generate_dequeue_op(tpu_device=0)
      # The unflatten process uses the structure information recorded above.
      return self._inputs_structure_recorder.unflatten_features_and_labels(
          values)

    return (enqueue_ops, dequeue_fn, all_hooks, run_infeed_loop_on_coordinator)

  def _invoke_input_fn_and_record_structure(self):
    """Deploys the input pipeline and record input structure."""
    enqueue_ops = []
    infeed_queues = []
    all_hooks = []
    num_hosts = self._ctx.num_hosts
    tpu_host_placement_fn = self._ctx.tpu_host_placement_function

    run_infeed_loop_on_coordinator = True

    if self._sharded_per_core:
      # Per-Core input pipeline deployment.
      # Invoke input pipeline for each core and placed on the corresponding
      # host.
      for host_id in range(num_hosts):
        host_device = tpu_host_placement_fn(host_id=host_id)
        with ops.device(host_device):
          with ops.name_scope('input_pipeline_task%d' % (host_id)):
            enqueue_ops_fn, captured_infeed_queue = (
                generate_per_core_enqueue_ops_fn_for_host(
                    self._ctx, self._input_fn, self._inputs_structure_recorder,
                    host_device, host_id))

            if _WRAP_INPUT_FN_INTO_WHILE_LOOP:
              run_infeed_loop_on_coordinator = False
              enqueue_ops.append(
                  _wrap_computation_in_while_loop(
                      device=host_device, op_fn=enqueue_ops_fn))
            else:
              enqueue_ops.append(enqueue_ops_fn())
            # Infeed_queue_getter must be called after enqueue_ops_fn is called.
            infeed_queues.append(captured_infeed_queue.get())

    elif self._ctx.is_input_broadcast_with_iterators():
      # Only calls input_fn in host 0.
      host_device = tpu_host_placement_fn(host_id=0)
      enqueue_ops_fn, captured_infeed_queue, hooks, is_dataset = (
          generate_broadcast_enqueue_ops_fn(self._ctx, self._input_fn,
                                            self._inputs_structure_recorder,
                                            num_hosts))
      all_hooks.extend(hooks)
      if is_dataset:
        run_infeed_loop_on_coordinator = False
        wrap_fn = (
            _wrap_computation_in_while_loop
            if self._ctx.mode != model_fn_lib.ModeKeys.PREDICT else
            _wrap_computation_in_while_loop_with_stopping_signals)
        enqueue_ops.append(wrap_fn(device=host_device, op_fn=enqueue_ops_fn))
      else:
        enqueue_ops.append(enqueue_ops_fn())
      infeed_queues.append(captured_infeed_queue.get())

    else:
      for host_id in range(num_hosts):
        host_device = tpu_host_placement_fn(host_id=host_id)
        with ops.device(host_device):
          with ops.name_scope('input_pipeline_task%d' % (host_id)):
            if self._ctx.is_input_per_host_with_iterators():
              enqueue_ops_fn, captured_infeed_queue, hooks, is_dataset = (
                  generate_per_host_v2_enqueue_ops_fn_for_host(
                      self._ctx, self._input_fn,
                      self._inputs_structure_recorder, host_device, host_id))
            else:
              enqueue_ops_fn, captured_infeed_queue, hooks, is_dataset = (
                  generate_per_host_enqueue_ops_fn_for_host(
                      self._ctx, self._input_fn,
                      self._inputs_structure_recorder, self._batch_axis,
                      host_device, host_id))
            all_hooks.extend(hooks)

            # NOTE(xiejw): We dispatch here based on the return type of the
            # users `input_fn`.
            #
            # 1. If input_fn returns a Dataset instance, we initialize the
            # iterator outside of tf.while_loop, and call the iterator.get_next
            # inside tf.while_loop. This should be always safe.
            #
            # 2. If input_fn returns (features, labels), it is too late to wrap
            # them inside tf.while_loop, as resource initialization cannot be
            # handled in TF control flow properly. In this case, we will use
            # python loop to enqueue the data into TPU system. This may be
            # slow compared to the previous case.
            if is_dataset:
              run_infeed_loop_on_coordinator = False
              wrap_fn = (
                  _wrap_computation_in_while_loop
                  if self._ctx.mode != model_fn_lib.ModeKeys.PREDICT else
                  _wrap_computation_in_while_loop_with_stopping_signals)
              enqueue_ops.append(
                  wrap_fn(device=host_device, op_fn=enqueue_ops_fn))
            else:
              enqueue_ops.append(enqueue_ops_fn())
            infeed_queues.append(captured_infeed_queue.get())
    # infeed_queue is used to generate dequeue ops. The only thing it uses for
    # dequeue is dtypes and types. So, any one can be used. Here, grab the
    # first one.
    self._infeed_queue = infeed_queues[0]
    return enqueue_ops, all_hooks, run_infeed_loop_on_coordinator

  def _validate_input_pipeline(self):
    """Validates the input pipeline.

    Perform some sanity checks to log user friendly information. We should
    error out to give users better error message. But, if
    _WRAP_INPUT_FN_INTO_WHILE_LOOP is False (legacy behavior), we cannot break
    user code, so, log a warning.

    Raises:
      RuntimeError: If the validation failed.
    """
    if ops.get_default_graph().get_collection(ops.GraphKeys.QUEUE_RUNNERS):
      err_msg = ('Input pipeline contains one or more QueueRunners. '
                 'It could be slow and not scalable. Please consider '
                 'converting your input pipeline to use `tf.data` instead (see '
                 'https://www.tensorflow.org/guide/datasets for '
                 'instructions.')
      if _WRAP_INPUT_FN_INTO_WHILE_LOOP:
        raise RuntimeError(err_msg)
      else:
        logging.warn(err_msg)

class _ModelFnWrapper(object):
  """A `model_fn` wrapper.

  This makes calling model_fn on CPU and TPU easier and more consistent and
  performs necessary check and mutation required by TPU training and
  evaluation. In addition, this wrapper manages converting the `model_fn` to a
  single TPU train and eval step.
  """

  def __init__(self, model_fn, train_cache_fn, eval_cache_fn, config, params,
               ctx):
    self._model_fn = model_fn
    self._train_cache_fn = train_cache_fn
    self._eval_cache_fn = eval_cache_fn
    self._config = config
    self._params = params
    self._ctx = ctx

  def call_without_tpu(self, features, labels, is_export_mode):
    return self._call_model_fn(features, labels, is_export_mode=is_export_mode)
  def convert_to_single_tpu_train_step(self, dequeue_fn):
    """Converts the user-provided `model_fn` into a single train step on TPU.

    The user provided `model_fn` takes input tuple (features, labels) and
    produces the EstimatorSpec with train_op and loss for train `mode`. This
    usually represents a single train computation on CPU.

    For TPU training, a train (computation) step is first wrapped in a
    tf.while_loop control flow to repeat for many times and then replicated to
    all TPU shards. Besides, the input should be taken from TPU infeed rather
    than input pipeline (input_fn) directly. To fit the TPU loop and replicate
    pattern, the original train computation should be reformed, which is the
    returned `train_step`.

    Args:
      dequeue_fn: The function to retrieve inputs, features and labels, from
        TPU infeed dequeue channel.

    Returns:
      A tuple of train_fn, host_calls, and captured scaffold_fn. The train_fn
      representing the train step for TPU.
    """
    host_call = _OutfeedHostCall(self._ctx)
    captured_scaffold_fn = _CapturedObject()
    captured_training_hooks = _CapturedObject()

    def train_step(loss, *cache):
      """Training step function for use inside a while loop."""
      if not self._params.get('track_mean', False):
        del loss  # unused; required in function signature.

      inputs = dequeue_fn()
      features, labels = inputs.features_and_labels()

      # Consume the current cache
      estimator_spec = self._verify_estimator_spec(
          self._call_model_fn(features, labels, cache=cache))

      # Retrieve the new returned cache.
      # `cache` consists of a list of tensors, potentially empty (of length 0).
      cache = estimator_spec.cache
      new_loss, train_op = estimator_spec.loss, estimator_spec.train_op

      if isinstance(estimator_spec, model_fn_lib._TPUEstimatorSpec):  # pylint: disable=protected-access
        captured_scaffold_fn.capture(estimator_spec.scaffold_fn)
      else:
        captured_scaffold_fn.capture(None)

      captured_training_hooks.capture(estimator_spec.training_hooks)

      # We must run train_op to update the variables prior to running the
      # outfeed.
      with ops.control_dependencies([train_op]):
        host_call_outfeed_ops = []
        if (isinstance(estimator_spec, model_fn_lib._TPUEstimatorSpec)  # pylint: disable=protected-access
            and estimator_spec.host_call is not None):
          host_call.record({'host_call': estimator_spec.host_call})
          host_call_outfeed_ops = host_call.create_enqueue_op()
        with ops.control_dependencies(host_call_outfeed_ops):
          if self._params.get('track_mean', False):
            loss = tensorflow.stop_gradient(loss)
            return [math_ops.add(loss, new_loss)] + cache
          else:
            return [array_ops.identity(new_loss)] + cache

    return (train_step, host_call, captured_scaffold_fn,
            captured_training_hooks)
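  # Illustrative sketch (not part of the original file): the cache plumbing
  # consumed by `train_step` above expects a Transformer-XL style `model_fn`
  # to read the previous segment's memory from `params['cache']` and return
  # the updated memory in the `cache` field of the (modified) TPUEstimatorSpec.
  # Roughly, under hypothetical names (`build_transformer_xl` is a placeholder
  # for the model code that lives in this repository's model.py / train.py):
  #
  #   def model_fn(features, labels, mode, params):
  #     mems = params['cache']                               # incoming memory
  #     loss, train_op, new_mems = build_transformer_xl(features, mems)
  #     return TPUEstimatorSpec(
  #         mode=mode, loss=loss, train_op=train_op, cache=new_mems)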
  def convert_to_single_tpu_eval_step(self, dequeue_fn):
    """Converts the user-provided `model_fn` into a single eval step on TPU.

    Similar to training, the user provided `model_fn` takes input tuple
    (features, labels) and produces the TPUEstimatorSpec with eval_metrics for
    eval `mode`. This usually represents a single evaluation computation on
    CPU.

    For TPU evaluation, an eval (computation) step is first wrapped in a
    tf.while_loop control flow to repeat for many times and then replicated to
    all TPU shards. Besides, the input and output are slightly different.
    Input, features and labels, should be taken from TPU infeed rather than
    input pipeline (input_fn) directly. Output is managed in two stages.
    First, the model outputs as the result of evaluation computation, usually
    model logits, should be transferred from TPU system to CPU. Then, all
    model outputs are concatenated first on CPU and sent to the metric_fn for
    metrics computation. To fit the TPU evaluation pattern, the original eval
    computation should be reformed, which is the returned `eval_step`.

    Args:
      dequeue_fn: The function to retrieve inputs, features and labels, from
        TPU infeed dequeue channel.

    Returns:
      A tuple of eval_fn, host_calls, and captured scaffold_fn. The eval_fn
      representing the eval step for TPU.
    """
    host_calls = _OutfeedHostCall(self._ctx)
    captured_scaffold_fn = _CapturedObject()
    captured_eval_hooks = _CapturedObject()

    def eval_step(total_loss, *cache):
      """Evaluation step function for use inside a while loop."""
      inputs = dequeue_fn()
      features, labels = inputs.features_and_labels()

      # Consume the current cache
      tpu_estimator_spec = self._call_model_fn(features, labels, cache=cache)
      if not isinstance(tpu_estimator_spec, model_fn_lib._TPUEstimatorSpec):  # pylint: disable=protected-access
        raise RuntimeError(
            'estimator_spec used by TPU evaluation must have type'
            '`TPUEstimatorSpec`. Got {}'.format(type(tpu_estimator_spec)))

      # Retrieve the new returned cache
      cache = tpu_estimator_spec.cache
      loss = tpu_estimator_spec.loss

      captured_scaffold_fn.capture(tpu_estimator_spec.scaffold_fn)
      captured_eval_hooks.capture(tpu_estimator_spec.evaluation_hooks)

      to_record = {}
      if tpu_estimator_spec.eval_metrics:
        to_record['eval_metrics'] = tpu_estimator_spec.eval_metrics
      if tpu_estimator_spec.host_call is not None:
        # We assume that evaluate won't update global step, so we don't wrap
        # this host_call.
        to_record['host_call'] = tpu_estimator_spec.host_call
      host_calls.record(to_record)

      with ops.control_dependencies(host_calls.create_enqueue_op()):
        return [math_ops.add(total_loss, loss)] + cache

    return eval_step, host_calls, captured_scaffold_fn, captured_eval_hooks

  def convert_to_single_tpu_predict_step(self, dequeue_fn):
    """Converts the user-provided `model_fn` into a single predict step on TPU.

    Args:
      dequeue_fn: The function to retrieve inputs, features and labels, from
        TPU infeed dequeue channel.

    Returns:
      A tuple of predict_fn, host_calls, and captured scaffold_fn. The
      predict_fn representing the predict step for TPU.
    """
    host_calls = _OutfeedHostCall(self._ctx)
    captured_scaffold_fn = _CapturedObject()
    captured_predict_hooks = _CapturedObject()

    def predict_step(unused_scalar_stopping_signal):
      """Prediction step function for use inside a while loop."""
      inputs = dequeue_fn()
      features, labels = inputs.features_and_labels()
      stopping_signals = inputs.signals()

      assert stopping_signals is not None, (
          'Internal Error: `signals` is missing.')

      tpu_estimator_spec = self._call_model_fn(
          features, labels, is_export_mode=False)
      if not isinstance(tpu_estimator_spec, model_fn_lib._TPUEstimatorSpec):  # pylint: disable=protected-access
        raise RuntimeError(
            'estimator_spec used by TPU prediction must have type'
            '`TPUEstimatorSpec`. Got {}'.format(type(tpu_estimator_spec)))

      self._verify_tpu_spec_predictions(tpu_estimator_spec.predictions)

      captured_scaffold_fn.capture(tpu_estimator_spec.scaffold_fn)
      captured_predict_hooks.capture(tpu_estimator_spec.prediction_hooks)
      to_record = {}
      identity_fn = lambda **kwargs: kwargs
      to_record['predictions'] = [identity_fn, tpu_estimator_spec.predictions]
      to_record['signals'] = [identity_fn, stopping_signals]
      if tpu_estimator_spec.host_call is not None:
        to_record['host_call'] = tpu_estimator_spec.host_call
      host_calls.record(to_record)

      with ops.control_dependencies(host_calls.create_enqueue_op()):
        return _StopSignals.as_scalar_stopping_signal(stopping_signals)

    return (predict_step, host_calls, captured_scaffold_fn,
            captured_predict_hooks)

  def _verify_tpu_spec_predictions(self, predictions):
    """Validates TPUEstimatorSpec.predictions dict."""
    # TODO(xiejw): Adds validation for prediction dictionary.
    # TODO(xiejw): Adds support for single tensor as predictions.
    if not isinstance(predictions, dict):
      raise TypeError('TPUEstimatorSpec.predictions must be dict of Tensors.')

    for (key, tensor) in predictions.items():
      if tensor.shape[0].value is None:
        raise ValueError(
            'The tensor with key ({}) in TPUEstimatorSpec.predictions has '
            'dynamic shape (should be static). Tensor: {}'.format(key, tensor))
    return predictions
  def _validate_model_features_and_labels(self, features, labels,
                                          is_export_mode):
    """Validates that the features and labels for the model function are valid.

    A valid features/labels object is the one with:
    - Type: Tensor or a dictionary of Tensors
    - Static shape if is_export_mode is False.

    Args:
      features: the features that would be input to the model function.
      labels: the labels that would be input to the model function.
      is_export_mode: boolean value specifying if in export mode.

    Raises:
      TypeError: If features/labels are not of the correct type.
      ValueError: If features/labels have dynamic shape.
    """

    def validate(obj, obj_name):
      """Helper validate function."""
      if not isinstance(obj, ops.Tensor) and not isinstance(obj, dict):
        raise TypeError(
            'The {} to the model returned by input_fn must be either a Tensor '
            'or a dictionary of Tensors. {}: {}'.format(obj_name, obj_name,
                                                        obj))
      if is_export_mode or self._ctx.is_running_on_cpu(is_export_mode):
        return
      if isinstance(obj, ops.Tensor):
        if not obj.get_shape().is_fully_defined():
          raise ValueError(
              'The {} to the model returned by input_fn must have static '
              'shape. Tensor: {}'.format(obj_name, obj))
      else:
        for (key, value) in obj.items():
          flattened_tensors = data_nest.flatten(value)
          for tensor in flattened_tensors:
            if not tensor.get_shape().is_fully_defined():
              raise ValueError(
                  'The {} to the model returned by input_fn must have static '
                  'shape. Key: \'{}\', Tensor: {}'.format(
                      obj_name, key, tensor))

    validate(features, 'features')
    if labels is not None:
      validate(labels, 'labels')
  def _call_model_fn(self, features, labels, cache=None, is_export_mode=False):
    """Calls the model_fn with required parameters."""
    self._validate_model_features_and_labels(features, labels, is_export_mode)
    model_fn_args = function_utils.fn_args(self._model_fn)
    kwargs = {}

    # Makes deep copy with `config` and `params` in case user mutates them.
    config = copy.deepcopy(self._config)
    params = copy.deepcopy(self._params)

    if 'labels' in model_fn_args:
      kwargs['labels'] = labels
    elif labels is not None:
      raise ValueError(
          'model_fn does not take labels, but input_fn returns labels.')
    if 'mode' in model_fn_args:
      kwargs['mode'] = self._ctx.mode
    if 'config' in model_fn_args:
      kwargs['config'] = config
    if 'params' in model_fn_args:
      kwargs['params'] = params

    if cache is not None:
      params['cache'] = cache

    if 'params' not in model_fn_args:
      raise ValueError('model_fn ({}) does not include params argument, '
                       'required by TPUEstimator to pass batch size as '
                       'params[\'batch_size\']'.format(self._model_fn))

    if is_export_mode:
      batch_size_for_model_fn = None
    else:
      batch_size_for_model_fn = self._ctx.batch_size_for_model_fn

    if batch_size_for_model_fn is not None:
      _add_item_to_params(params, _BATCH_SIZE_KEY, batch_size_for_model_fn)

    running_on_cpu = self._ctx.is_running_on_cpu(is_export_mode)
    _add_item_to_params(params, _USE_TPU_KEY, not running_on_cpu)

    if not running_on_cpu:
      user_context = tpu_context.TPUContext(
          internal_ctx=self._ctx, call_from_input_fn=False)
      _add_item_to_params(params, _CTX_KEY, user_context)

    estimator_spec = self._model_fn(features=features, **kwargs)
    if (running_on_cpu and
        isinstance(estimator_spec, model_fn_lib._TPUEstimatorSpec)):  # pylint: disable=protected-access
      # The estimator_spec will be passed to `Estimator` directly, which
      # expects type `EstimatorSpec`.
      return estimator_spec.as_estimator_spec()
    else:
      return estimator_spec
  def _verify_estimator_spec(self, estimator_spec):
    """Validates the estimator_spec."""
    if isinstance(estimator_spec, model_fn_lib._TPUEstimatorSpec):  # pylint: disable=protected-access
      return estimator_spec

    err_msg = '{} returned by EstimatorSpec is not supported in TPUEstimator.'
    if estimator_spec.training_chief_hooks:
      raise ValueError(
          err_msg.format('training_chief_hooks') + 'If you want' +
          ' to pass training hooks, please pass via training_hooks.')

    if estimator_spec.scaffold:
      logging.warning('EstimatorSpec.Scaffold is ignored by TPU train/eval. '
                      'Please use TPUEstimatorSpec.')
    return estimator_spec
class _OutfeedHostCall(object):
  """Support for `eval_metrics` and `host_call` in TPUEstimatorSpec."""

  def __init__(self, ctx):
    self._ctx = ctx
    self._names = []
    # All of these are dictionaries of lists keyed on the name.
    self._host_fns = {}
    self._tensor_keys = collections.defaultdict(list)
    self._tensors = collections.defaultdict(list)
    self._tensor_dtypes = collections.defaultdict(list)
    self._tensor_shapes = collections.defaultdict(list)

  @staticmethod
  def validate(host_calls):
    """Validates the `eval_metrics` and `host_call` in `TPUEstimatorSpec`."""

    for name, host_call in host_calls.items():
      if not isinstance(host_call, (tuple, list)):
        raise ValueError('{} should be tuple or list'.format(name))
      if len(host_call) != 2:
        raise ValueError('{} should have two elements.'.format(name))
      if not callable(host_call[0]):
        raise TypeError('{}[0] should be callable.'.format(name))
      if not isinstance(host_call[1], (tuple, list, dict)):
        raise ValueError('{}[1] should be tuple or list, or dict.'.format(name))

      if isinstance(host_call[1], (tuple, list)):
        fullargspec = tf_inspect.getfullargspec(host_call[0])
        fn_args = function_utils.fn_args(host_call[0])
        # wrapped_hostcall_with_global_step uses varargs, so we allow that.
        if fullargspec.varargs is None and len(host_call[1]) != len(fn_args):
          raise RuntimeError(
              'In TPUEstimatorSpec.{}, length of tensors {} does not match '
              'method args of the function, which takes {}.'.format(
                  name, len(host_call[1]), len(fn_args)))
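  # Illustrative sketch (not part of the original file): a `host_call` that
  # passes `validate` above is a 2-element tuple whose first element is a
  # callable and whose second element is a list/tuple (or dict) of Tensors
  # matching that callable's signature, e.g. with hypothetical names:
  #
  #   def _host_fn(global_step, loss):
  #     ...  # write summaries or other host-side ops on CPU
  #     return summary_ops
  #
  #   host_call = (_host_fn, [global_step_t, loss_t])
  #
  # With a dict as the second element, the tensors are passed as keyword
  # arguments instead (see `create_cpu_hostcall` below).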
  @staticmethod
  def create_cpu_hostcall(host_calls):
    """Runs the host_call on CPU instead of TPU when use_tpu=False."""

    _OutfeedHostCall.validate(host_calls)
    ret = {}
    for name, host_call in host_calls.items():
      host_fn, tensors = host_call
      if isinstance(tensors, (tuple, list)):
        ret[name] = host_fn(*tensors)
      else:
        # Must be dict.
        try:
          ret[name] = host_fn(**tensors)
        except TypeError as e:
          logging.warning(
              'Exception while calling %s: %s. It is likely the tensors '
              '(%s[1]) do not match the '
              'function\'s arguments', name, e, name)
          raise e
    return ret
  def record(self, host_calls):
    """Records the host_call structure."""

    for name, host_call in host_calls.items():
      host_fn, tensor_list_or_dict = host_call
      self._names.append(name)
      self._host_fns[name] = host_fn

      if isinstance(tensor_list_or_dict, dict):
        for (key, tensor) in six.iteritems(tensor_list_or_dict):
          self._tensor_keys[name].append(key)
          self._tensors[name].append(tensor)
          self._tensor_dtypes[name].append(tensor.dtype)
          self._tensor_shapes[name].append(tensor.shape)
      else:
        # List or tuple.
        self._tensor_keys[name] = None
        for tensor in tensor_list_or_dict:
          self._tensors[name].append(tensor)
          self._tensor_dtypes[name].append(tensor.dtype)
          self._tensor_shapes[name].append(tensor.shape)
  def create_enqueue_op(self):
    """Create the op to enqueue the recorded host_calls.

    Returns:
      A list of enqueue ops, which is empty if there are no host calls.
    """
    if not self._names:
      return []

    tensors = []
    # TODO(jhseu): Consider deduping tensors.
    for name in self._names:
      tensors.extend(self._tensors[name])

    with ops.device(tpu.core(0)):
      return [tpu_ops.outfeed_enqueue_tuple(tensors)]
  def create_tpu_hostcall(self):
    """Sends the tensors through outfeed and runs the host_fn on CPU.

    The tensors are concatenated along dimension 0 to form a global tensor
    across all shards. The concatenated function is passed to the host_fn and
    executed on the first host.

    Returns:
      A dictionary mapping name to the return type of the host_call by that
      name.

    Raises:
      RuntimeError: If outfeed tensor is scalar.
    """
    if not self._names:
      return {}

    ret = {}
    # For each i, dequeue_ops[i] is a list containing the tensors from all
    # shards. This list is concatenated later.
    dequeue_ops = []
    tensor_dtypes = []
    tensor_shapes = []
    for name in self._names:
      for _ in self._tensors[name]:
        dequeue_ops.append([])
      for dtype in self._tensor_dtypes[name]:
        tensor_dtypes.append(dtype)
      for shape in self._tensor_shapes[name]:
        tensor_shapes.append(shape)

    # Outfeed ops execute on each replica's first logical core. Note: we must
    # constrain it such that we have at most one outfeed dequeue and enqueue
    # per replica.
    for i in xrange(self._ctx.num_replicas):
      host_device, ordinal_id = self._ctx.device_for_replica(i)
      with ops.device(host_device):
        outfeed_tensors = tpu_ops.outfeed_dequeue_tuple(
            dtypes=tensor_dtypes,
            shapes=tensor_shapes,
            device_ordinal=ordinal_id)
        for j, item in enumerate(outfeed_tensors):
          dequeue_ops[j].append(item)

    # Deconstruct dequeue ops.
    dequeue_ops_by_name = {}
    pos = 0
    for name in self._names:
      dequeue_ops_by_name[name] = dequeue_ops[pos:pos +
                                              len(self._tensors[name])]
      pos += len(self._tensors[name])

    # It is assumed evaluation always happens on single host TPU system. So,
    # place all ops on tpu host if possible.
    #
    # TODO(jhseu): Evaluate whether this is right for summaries.
    with ops.device(self._ctx.tpu_host_placement_function(replica_id=0)):
      for name in self._names:
        dequeue_ops = dequeue_ops_by_name[name]
        for i, item in enumerate(dequeue_ops):
          if dequeue_ops[i][0].shape.ndims == 0:
            raise RuntimeError(
                'All tensors outfed from TPU should preserve batch size '
                'dimension, but got scalar {}'.format(dequeue_ops[i][0]))
          # TODO(xiejw): Allow users to specify the axis for batch size
          # dimension.
          dequeue_ops[i] = array_ops.concat(dequeue_ops[i], axis=0)

        if self._tensor_keys[name] is not None:
          # The user-provided eval_metrics[1] is a dict.
          dequeue_ops = dict(zip(self._tensor_keys[name], dequeue_ops))
          try:
            ret[name] = self._host_fns[name](**dequeue_ops)
          except TypeError as e:
            logging.warning(
                'Exception while calling %s: %s. It is likely the tensors '
                '(%s[1]) do not match the '
                'function\'s arguments', name, e, name)
            raise e
        else:
          ret[name] = self._host_fns[name](*dequeue_ops)

    return ret
class _OutfeedHostCallHook(session_run_hook.SessionRunHook):
  """Hook to run host calls when use_tpu=False."""

  def __init__(self, tensors):
    self._tensors = tensors

  def begin(self):
    # We duplicate this code from the TPUInfeedOutfeedSessionHook rather than
    # create a separate hook to guarantee execution order, because summaries
    # need to be initialized before the outfeed thread starts.
    # TODO(jhseu): Make a wrapper hook instead?
    self._init_ops = contrib_summary.summary_writer_initializer_op()
    # Get all the writer resources from the initializer, so we know what to
    # flush.
    self._finalize_ops = []
    for op in self._init_ops:
      self._finalize_ops.append(contrib_summary.flush(writer=op.inputs[0]))

  def after_create_session(self, session, coord):
    session.run(self._init_ops)

  def before_run(self, run_context):
    return basic_session_run_hooks.SessionRunArgs(self._tensors)

  def end(self, session):
    session.run(self._finalize_ops)
class ExamplesPerSecondHook(basic_session_run_hooks.StepCounterHook):
  """Calculate and report global_step/sec and examples/sec during runtime."""

  def __init__(self,
               batch_size,
               every_n_steps=100,
               every_n_secs=None,
               output_dir=None,
               summary_writer=None):
    self._batch_size = batch_size
    super(ExamplesPerSecondHook, self).__init__(
        every_n_steps=every_n_steps,
        every_n_secs=every_n_secs,
        output_dir=output_dir,
        summary_writer=summary_writer)

  def _log_and_record(self, elapsed_steps, elapsed_time, global_step):
    global_step_per_sec = elapsed_steps / elapsed_time
    examples_per_sec = self._batch_size * global_step_per_sec
    if self._summary_writer is not None:
      global_step_summary = Summary(value=[
          Summary.Value(tag='global_step/sec',
                        simple_value=global_step_per_sec)
      ])
      example_summary = Summary(value=[
          Summary.Value(tag='examples/sec', simple_value=examples_per_sec)
      ])
      self._summary_writer.add_summary(global_step_summary, global_step)
      self._summary_writer.add_summary(example_summary, global_step)
    logging.info('global_step/sec: %g', global_step_per_sec)
    logging.info('examples/sec: %g', examples_per_sec)
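# Illustrative sketch (not part of the original file): the throughput numbers
# logged by ExamplesPerSecondHook above reduce to this arithmetic, where
# `batch_size` is the global batch size passed to the hook's constructor.
def _examples_per_sec_sketch(batch_size, elapsed_steps, elapsed_time):
  """Returns (global_step/sec, examples/sec) for one logging interval."""
  global_step_per_sec = elapsed_steps / float(elapsed_time)
  return global_step_per_sec, batch_size * global_step_per_sec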
class InstallSignalHandlerHook(session_run_hook.SessionRunHook):
  """Change SIGINT (CTRL^C) handler to force quit the process.

  The default behavior often results in hanging processes.
  The original handler is restored after training/evaluation.
  """

  def __init__(self):
    self._signal_fn = signal.getsignal(signal.SIGINT)

  def before_run(self, run_context):
    signal.signal(signal.SIGINT, signal.SIG_DFL)

  def end(self, session):
    signal.signal(signal.SIGINT, self._signal_fn)
class TPUEstimator(estimator_lib.Estimator):
  """Estimator with TPU support.

  TPUEstimator also supports training on CPU and GPU. You don't need to define
  a separate `tf.estimator.Estimator`.

  TPUEstimator handles many of the details of running on TPU devices, such as
  replicating inputs and models for each core, and returning to host
  periodically to run hooks.

  TPUEstimator transforms a global batch size in params to a per-shard batch
  size when calling the `input_fn` and `model_fn`. Users should specify
  global batch size in constructor, and then get the batch size for each shard
  in `input_fn` and `model_fn` by `params['batch_size']`.

  - For training, `model_fn` gets per-core batch size; `input_fn` may get
    per-core or per-host batch size depending on `per_host_input_for_training`
    in `TPUConfig` (See docstring for TPUConfig for details).

  - For evaluation and prediction, `model_fn` gets per-core batch size and
    `input_fn` gets per-host batch size.

  Evaluation
  ==========

  `model_fn` should return `TPUEstimatorSpec`, which expects the `eval_metrics`
  for TPU evaluation. However, if eval_on_tpu is False, `model_fn` must return
  `EstimatorSpec` and the evaluation will execute on CPU or GPU; in this case
  the following discussion on TPU evaluation does not apply.

  `TPUEstimatorSpec.eval_metrics` is a tuple of `metric_fn` and `tensors`,
  where `tensors` could be a list of `Tensor`s or dict of names to `Tensor`s.
  (See `TPUEstimatorSpec` for details). `metric_fn` takes the `tensors` and
  returns a dict from metric string name to the result of calling a metric
  function, namely a `(metric_tensor, update_op)` tuple.

  One can set `use_tpu` to `False` for testing. All training, evaluation, and
  predict will be executed on CPU. `input_fn` and `model_fn` will receive
  `train_batch_size` or `eval_batch_size` unmodified as `params['batch_size']`.

  Current limitations:
  --------------------

  1. TPU evaluation only works on a single host (one TPU worker) except
     BROADCAST mode.

  2. `input_fn` for evaluation should **NOT** raise an end-of-input exception
     (`OutOfRangeError` or `StopIteration`). And all evaluation steps and all
     batches should have the same size.

  Example (MNIST):
  ----------------

  ```
  # The metric Fn which runs on CPU.
  def metric_fn(labels, logits):
    predictions = tf.argmax(logits, 1)
    return {
      'accuracy': tf.metrics.precision(
          labels=labels, predictions=predictions),
    }

  # Your model Fn which runs on TPU (eval_metrics is list in this example)
  def model_fn(features, labels, mode, config, params):
    ...
    logits = ...

    if mode == tf.estimator.ModeKeys.EVAL:
      return tpu_estimator.TPUEstimatorSpec(
          mode=mode,
          loss=loss,
          eval_metrics=(metric_fn, [labels, logits]))

  # or specify the eval_metrics tensors as dict.
  def model_fn(features, labels, mode, config, params):
    ...
    final_layer_output = ...

    if mode == tf.estimator.ModeKeys.EVAL:
      return tpu_estimator.TPUEstimatorSpec(
          mode=mode,
          loss=loss,
          eval_metrics=(metric_fn, {
              'labels': labels,
              'logits': final_layer_output,
          }))
  ```

  Prediction
  ==========

  Prediction on TPU is an experimental feature to support large batch
  inference. It is not designed for latency-critical systems. In addition, due
  to some usability issues, for prediction with small dataset, CPU `.predict`,
  i.e., creating a new `TPUEstimator` instance with `use_tpu=False`, might be
  more convenient.

  Note: In contrast to TPU training/evaluation, the `input_fn` for prediction
  *should* raise an end-of-input exception (`OutOfRangeError` or
  `StopIteration`), which serves as the stopping signal to `TPUEstimator`. To
  be precise, the ops created by `input_fn` produce one batch of the data.
  The `predict()` API processes one batch at a time. When reaching the end of
  the data source, an end-of-input exception should be raised by one of these
  operations. The user usually does not need to do this manually. As long as
  the dataset is not repeated forever, the `tf.data` API will raise an
  end-of-input exception automatically after the last batch has been produced.

  Note: Estimator.predict returns a Python generator. Please consume all the
  data from the generator so that TPUEstimator can shut down the TPU system
  properly for the user.

  Current limitations:
  --------------------

  1. TPU prediction only works on a single host (one TPU worker).

  2. `input_fn` must return a `Dataset` instance rather than `features`. In
     fact, .train() and .evaluate() also support Dataset as return value.

  Example (MNIST):
  ----------------
  ```
  height = 32
  width = 32
  total_examples = 100

  def predict_input_fn(params):
    batch_size = params['batch_size']

    images = tf.random_uniform(
        [total_examples, height, width, 3], minval=-1, maxval=1)

    dataset = tf.data.Dataset.from_tensor_slices(images)
    dataset = dataset.map(lambda images: {'image': images})

    dataset = dataset.batch(batch_size)
    return dataset

  def model_fn(features, labels, params, mode):
    # Generate predictions, called 'output', from features['image']

    if mode == tf.estimator.ModeKeys.PREDICT:
      return tf.contrib.tpu.TPUEstimatorSpec(
          mode=mode,
          predictions={
              'predictions': output,
              'is_padding': features['is_padding']
          })

  tpu_est = TPUEstimator(
      model_fn=model_fn,
      ...,
      predict_batch_size=16)

  # Fully consume the generator so that TPUEstimator can shutdown the TPU
  # system.
  for item in tpu_est.predict(input_fn=input_fn):
    # Filter out item if the `is_padding` is 1.
    # Process the 'predictions'
  ```

  Exporting
  =========

  `export_savedmodel` exports 2 metagraphs, one with `tag_constants.SERVING`,
  and another with `tag_constants.SERVING` and `tag_constants.TPU`.
  At serving time, these tags are used to select the metagraph to load.

  Before running the graph on TPU, the TPU system needs to be initialized. If
  TensorFlow Serving model-server is used, this is done automatically. If
  not, please call `session.run(tpu.initialize_system())`.

  `tpu.outside_compilation` can be used to wrap TPU incompatible ops in
  `model_fn`.

  Example:
  ----------------

  ```
  def model_fn(features, labels, mode, config, params):
    ...
    logits = ...
    export_outputs = {
      'logits': export_output_lib.PredictOutput(
        {'logits': logits})
    }

    def host_call(logits):
      class_ids = math_ops.argmax(logits)
      classes = string_ops.as_string(class_ids)
      export_outputs['classes'] =
        export_output_lib.ClassificationOutput(classes=classes)

    tpu.outside_compilation(host_call, logits)
    ...
  ```
  """
  def __init__(self,
               model_fn=None,
               train_cache_fn=None,
               eval_cache_fn=None,
               model_dir=None,
               config=None,
               params=None,
               use_tpu=True,
               train_batch_size=None,
               eval_batch_size=None,
               predict_batch_size=None,
               batch_axis=None,
               eval_on_tpu=True,
               export_to_tpu=True,
               warm_start_from=None):
    """Constructs a `TPUEstimator` instance.

    Args:
      model_fn: Model function as required by `Estimator` which returns
        EstimatorSpec or TPUEstimatorSpec. `training_hooks`,
        `evaluation_hooks`, and `prediction_hooks` must not capture any TPU
        Tensor inside the model_fn.
      model_dir: Directory to save model parameters, graph and etc. This can
        also be used to load checkpoints from the directory into an estimator
        to continue training a previously saved model. If `None`, the
        model_dir in `config` will be used if set. If both are set, they must
        be same. If both are `None`, a temporary directory will be used.
      config: An `tpu_config.RunConfig` configuration object. Cannot be `None`.
      params: An optional `dict` of hyper parameters that will be passed into
        `input_fn` and `model_fn`. Keys are names of parameters, values are
        basic python types. There are reserved keys for `TPUEstimator`,
        including 'batch_size'.
      use_tpu: A bool indicating whether TPU support is enabled. Currently,
        - TPU training and evaluation respect this bit, but eval_on_tpu can
          override execution of eval. See below.
        - Predict still happens on CPU.
      train_batch_size: An int representing the global training batch size.
        TPUEstimator transforms this global batch size to a per-shard batch
        size, as params['batch_size'], when calling `input_fn` and `model_fn`.
        Cannot be `None` if `use_tpu` is `True`.
        Must be divisible by total number of replicas.
      eval_batch_size: An int representing evaluation batch size.
        Must be divisible by total number of replicas.
      predict_batch_size: An int representing the prediction batch size.
        Must be divisible by total number of replicas.
      batch_axis: A python tuple of int values describing how each tensor
        produced by the Estimator `input_fn` should be split across the TPU
        compute shards. For example, if your input_fn produced (images,
        labels) where the images tensor is in `HWCN` format, your shard
        dimensions would be [3, 0], where 3 corresponds to the `N` dimension
        of your images Tensor, and 0 corresponds to the dimension along which
        to split the labels to match up with the corresponding images. If None
        is supplied, and per_host_input_for_training is True, batches will be
        sharded based on the major dimension. If
        tpu_config.per_host_input_for_training is False or `PER_HOST_V2`,
        batch_axis is ignored.
      eval_on_tpu: If False, evaluation runs on CPU or GPU. In this case, the
        model_fn must return `EstimatorSpec` when called with `mode` as `EVAL`.
      export_to_tpu: If True, `export_savedmodel()` exports a metagraph for
        serving on TPU besides the one on CPU.
      warm_start_from: Optional string filepath to a checkpoint or SavedModel
        to warm-start from, or a `tf.estimator.WarmStartSettings` object to
        fully configure warm-starting. If the string filepath is provided
        instead of a `WarmStartSettings`, then all variables are warm-started,
        and it is assumed that vocabularies and Tensor names are unchanged.

    Raises:
      ValueError: `params` has reserved keys already.
    """
    if config is None or not isinstance(config, tpu_config.RunConfig):
      raise ValueError(
          '`config` must be provided with type `tpu_config.RunConfig`')

    if params is not None and any(k in params for k in _RESERVED_PARAMS_KEYS):
      raise ValueError('{} are reserved keys but existed in params {}.'.format(
          _RESERVED_PARAMS_KEYS, params))

    if use_tpu:
      # Perform some very basic validations. More validations will be found in
      # _InternalTPUContext.
      if train_batch_size is None:
        raise ValueError('`train_batch_size` cannot be `None`')
      util_lib.check_positive_integer(train_batch_size, 'train_batch_size')

      if (config.tpu_config.per_host_input_for_training is
          tpu_config.InputPipelineConfig.PER_SHARD_V1 and
          config.tpu_config.num_cores_per_replica):
        raise ValueError(
            'Model parallelism only supports per host input for training. '
            'Please adjust TPURunconfig.per_host_input_for_training.')

      if eval_batch_size is not None:
        util_lib.check_positive_integer(eval_batch_size, 'eval_batch_size')

      if predict_batch_size is not None:
        util_lib.check_positive_integer(predict_batch_size,
                                        'predict_batch_size')

    # Verifies the model_fn signature according to Estimator framework.
    estimator_lib._verify_model_fn_args(model_fn, params)  # pylint: disable=protected-access
    # We cannot store config and params in this constructor as parent
    # constructor might change them, such as assigning a temp dir for
    # config.model_dir.
    model_function = self._augment_model_fn(
        model_fn,
        train_cache_fn,
        eval_cache_fn,
        batch_axis)

    # Overwrite log_step_count_steps to disable TensorLoggingHook and
    # StepCounterHook from being created in Estimator. TPUEstimator already
    # added equivalent hooks in _augment_model_fn above.
    self._log_every_n_steps = config.log_step_count_steps
    config = config.replace(log_step_count_steps=None)

    # Passing non-None params as wrapped model_fn has it.
    params = params or {}
    super(TPUEstimator, self).__init__(
        model_fn=model_function,
        model_dir=model_dir,
        config=config,
        params=params,
        warm_start_from=warm_start_from)
    self._iterations_per_training_loop = (
        self._config.tpu_config.iterations_per_loop)

    # All properties passed to _InternalTPUContext are immutable.
    # pylint: disable=protected-access
    self._ctx = tpu_context._get_tpu_context(
        self._config, train_batch_size,
        eval_batch_size, predict_batch_size,
        use_tpu,
        eval_on_tpu)

    self._export_to_tpu = export_to_tpu

    self._is_input_fn_invoked = None

    self._rendezvous = {}
  def _add_meta_graph_for_mode(self,
                               builder,
                               input_receiver_fn_map,
                               checkpoint_path,
                               strip_default_attrs,
                               save_variables=True,
                               mode=model_fn_lib.ModeKeys.PREDICT,
                               export_tags=None,
                               check_variables=True):
    if self._export_to_tpu and mode != model_fn_lib.ModeKeys.PREDICT:
      raise NotImplementedError(
          'TPUEstimator only handles mode PREDICT for exporting '
          'when `export_to_tpu` is `True`; '
          'got {}.'.format(mode))

    (super(TPUEstimator, self)._add_meta_graph_for_mode(
        builder,
        input_receiver_fn_map,
        checkpoint_path,
        strip_default_attrs,
        save_variables,
        mode=mode,
        export_tags=export_tags,
        check_variables=check_variables))

    if self._export_to_tpu:
      input_receiver_fn_map = {
          _REWRITE_FOR_INFERENCE_MODE: input_receiver_fn_map[mode]
      }
      export_tags = [tag_constants.SERVING, tag_constants.TPU]
      mode = _REWRITE_FOR_INFERENCE_MODE

      # See b/110052256 for why `check_variables` is `False`.
      (super(TPUEstimator, self)._add_meta_graph_for_mode(
          builder,
          input_receiver_fn_map,
          checkpoint_path,
          strip_default_attrs,
          save_variables=False,
          mode=mode,
          export_tags=export_tags,
          check_variables=False))
  def _call_model_fn(self, features, labels, mode, config):
    if mode == _REWRITE_FOR_INFERENCE_MODE:
      return self._call_model_fn_for_inference(features, labels, mode, config)
    else:
      return super(TPUEstimator, self)._call_model_fn(features, labels, mode,
                                                      config)
  def _call_model_fn_for_inference(self, features, labels, mode, config):
    """Wraps `_call_model_fn` for `export_savedmodel`."""
    if mode != _REWRITE_FOR_INFERENCE_MODE:
      raise ValueError('mode must be {}; '
                       'got {}.'.format(_REWRITE_FOR_INFERENCE_MODE, mode))

    capture = _CapturedObject()

    def computation():
      """Compute tpu tensors used in export_outputs.

      Passed to rewrite_for_inference so that model_fn will be called under
      the rewriting contexts. Only tpu tensors are returned, but export_outputs
      and scaffold are captured.

      Returns:
        A list of Tensors used in export_outputs and not marked for
        outside_compilation.
      """
      # We should only call model fn once and it should be inside `computation`
      # so that building the graph will happen under `rewrite_for_inference`.
      mode = model_fn_lib.ModeKeys.PREDICT
      estimator_spec = self._call_model_fn(features, labels, mode, config)

      # We pick the TPU tensors out from `export_output` and later return them
      # from `computation` for rewriting.
      tensors_dict = collections.OrderedDict(
          (k, _export_output_to_tensors(v))
          for k, v in six.iteritems(estimator_spec.export_outputs))
      tensors = nest.flatten(tensors_dict)
      tpu_tensors = [t for t in tensors if _is_tpu_tensor(t)]

      # We cannot return anything other than `tpu_tensors` here so we capture
      # the rest for later use.
      capture.capture((estimator_spec, tensors_dict, tensors))
      return tpu_tensors

    tpu_tensors_on_cpu = tpu.rewrite_for_inference(computation)
    estimator_spec, tensors_dict, tensors = capture.get()

    # Reconstruct `tensors`, but with `tpu_tensors` replaced with
    # `tpu_tensors_on_cpu`.
    new_tensors = []
    for t in tensors:
      if _is_tpu_tensor(t):
        new_tensors.append(tpu_tensors_on_cpu.pop(0))
      elif t is None:
        new_tensors.append(None)
      else:
        # Only fetching `tpu_tensors_on_cpu` does not trigger
        # TPU computation and blocks, so we add the control dependency here.
        control_inputs = (
            tpu_tensors_on_cpu
            if isinstance(tpu_tensors_on_cpu, (list, tuple)) else
            (tpu_tensors_on_cpu,))
        with ops.control_dependencies(control_inputs):
          new_tensors.append(array_ops.identity(t))

    # Reconstruct `tensors_dict`.
    new_tensors_dict = nest.pack_sequence_as(tensors_dict, new_tensors)
    # Reconstruct `export_outputs`.
    export_outputs = estimator_spec.export_outputs
    new_export_outputs = collections.OrderedDict(
        (k, _clone_export_output_with_tensors(export_outputs[k], v))
        for k, v in six.iteritems(new_tensors_dict))

    return estimator_spec._replace(export_outputs=new_export_outputs)
  def _create_global_step(self, graph):
    """Creates a global step suitable for TPUs.

    Args:
      graph: The graph in which to create the global step.

    Returns:
      A global step `Tensor`.

    Raises:
      ValueError: if the global step tensor is already defined.
    """
    return _create_global_step(graph)

  def _convert_train_steps_to_hooks(self, steps, max_steps):
    with self._ctx.with_mode(model_fn_lib.ModeKeys.TRAIN) as ctx:
      if ctx.is_running_on_cpu():
        return super(TPUEstimator, self)._convert_train_steps_to_hooks(
            steps, max_steps)

    # On TPU.
    if steps is None and max_steps is None:
      raise ValueError(
          'For TPU training, one of `steps` or `max_steps` must be set. '
          'Cannot be both `None`.')

    # Estimator.train has explicit positiveness check.
    if steps is not None:
      util_lib.check_positive_integer(steps, 'Train steps')
    if max_steps is not None:
      util_lib.check_positive_integer(max_steps, 'Train max_steps')

    return [
        _TPUStopAtStepHook(self._iterations_per_training_loop, steps,
                           max_steps)
    ]
  def _convert_eval_steps_to_hooks(self, steps):
    with self._ctx.with_mode(model_fn_lib.ModeKeys.EVAL) as ctx:
      if ctx.is_running_on_cpu():
        return super(TPUEstimator, self)._convert_eval_steps_to_hooks(steps)

    if steps is None:
      raise ValueError('Evaluate `steps` must be set on TPU. Cannot be '
                       '`None`.')

    util_lib.check_positive_integer(steps, 'Eval steps')

    return [
        evaluation._StopAfterNEvalsHook(  # pylint: disable=protected-access
            num_evals=steps),
        _SetEvalIterationsHook(steps)
    ]
  def _call_input_fn(self, input_fn, mode):
    """Calls the input function.

    Args:
      input_fn: The input function.
      mode: ModeKeys

    Returns:
      Either features or (features, labels) where features and labels are:
        features - `Tensor` or dictionary of string feature name to `Tensor`.
        labels - `Tensor` or dictionary of `Tensor` with labels.

    Raises:
      ValueError: if input_fn takes invalid arguments or does not have
        `params`.
    """
    input_fn_args = function_utils.fn_args(input_fn)
    config = self.config  # a deep copy.
    kwargs = {}
    if 'params' in input_fn_args:
      kwargs['params'] = self.params  # a deep copy.
    else:
      raise ValueError('input_fn ({}) does not include params argument, '
                       'required by TPUEstimator to pass batch size as '
                       'params["batch_size"]'.format(input_fn))
    if 'config' in input_fn_args:
      kwargs['config'] = config

    if 'mode' in input_fn_args:
      kwargs['mode'] = mode

    # Records the fact input_fn has been invoked.
    self._is_input_fn_invoked = True

    with self._ctx.with_mode(mode) as ctx:
      # Setting the batch size in params first. This helps user to have same
      # input_fn for use_tpu=True/False.
      batch_size_for_input_fn = ctx.batch_size_for_input_fn
      if batch_size_for_input_fn is not None:
        _add_item_to_params(kwargs['params'], _BATCH_SIZE_KEY,
                            batch_size_for_input_fn)

      # For export_savedmodel, input_fn is never passed to Estimator. So,
      # `is_export_mode` must be False.
      if ctx.is_running_on_cpu(is_export_mode=False):
        with ops.device('/device:CPU:0'):
          return input_fn(**kwargs)

      # For TPU computation, input_fn should be invoked in a tf.while_loop for
      # performance. While constructing the tf.while_loop, the structure of
      # inputs returned by the `input_fn` needs to be recorded. The structure
      # includes whether features or labels is dict or single Tensor, dict
      # keys, tensor shapes, and dtypes. The recorded structure is used to
      # create the infeed dequeue ops, which must be wrapped and passed as a
      # Fn, called inside the TPU computation, as the TPU computation is
      # wrapped inside a tf.while_loop also. So, we either pass input_fn to
      # model_fn or pass dequeue_fn to model_fn. Here, `input_fn` is passed
      # directly as `features` in `model_fn` signature.
      def _input_fn(ctx):
        _add_item_to_params(kwargs['params'], _CTX_KEY, ctx)
        return input_fn(**kwargs)

      return _input_fn
  def _validate_features_in_predict_input(self, result):
    """Skip the validation.

    For TPUEstimator, we do not need to check the result type. `_InputPipeline`
    has stronger check. Parent class's check generates confusing warning msg.

    Args:
      result: `features` returned by input_fn.
    """
    pass
  def train(self,
            input_fn,
            hooks=None,
            steps=None,
            max_steps=None,
            saving_listeners=None):
    rendezvous = error_handling.ErrorRendezvous(num_sources=3)
    self._rendezvous[model_fn_lib.ModeKeys.TRAIN] = rendezvous
    try:
      return super(TPUEstimator, self).train(
          input_fn=input_fn,
          hooks=hooks,
          steps=steps,
          max_steps=max_steps,
          saving_listeners=saving_listeners)
    except Exception:  # pylint: disable=broad-except
      rendezvous.record_error('training_loop', sys.exc_info())
    finally:
      rendezvous.record_done('training_loop')
      rendezvous.raise_errors()

  def evaluate(self,
               input_fn,
               steps=None,
               hooks=None,
               checkpoint_path=None,
               name=None):
    rendezvous = error_handling.ErrorRendezvous(num_sources=3)
    self._rendezvous[model_fn_lib.ModeKeys.EVAL] = rendezvous
    try:
      return super(TPUEstimator, self).evaluate(
          input_fn,
          steps=steps,
          hooks=hooks,
          checkpoint_path=checkpoint_path,
          name=name)
    except Exception:  # pylint: disable=broad-except
      rendezvous.record_error('evaluation_loop', sys.exc_info())
    finally:
      rendezvous.record_done('evaluation_loop')
      rendezvous.raise_errors()

  def predict(self,
              input_fn,
              predict_keys=None,
              hooks=None,
              checkpoint_path=None,
              yield_single_examples=True):
    rendezvous = error_handling.ErrorRendezvous(num_sources=3)
    self._rendezvous[model_fn_lib.ModeKeys.PREDICT] = rendezvous
    try:
      for result in super(TPUEstimator, self).predict(
          input_fn=input_fn,
          predict_keys=predict_keys,
          hooks=hooks,
          checkpoint_path=checkpoint_path,
          yield_single_examples=yield_single_examples):
        yield result
    except Exception:  # pylint: disable=broad-except
      rendezvous.record_error('prediction_loop', sys.exc_info())
    finally:
      rendezvous.record_done('prediction_loop')
      rendezvous.raise_errors()

    rendezvous.record_done('prediction_loop')
    rendezvous.raise_errors()
  def _augment_model_fn(self, model_fn, train_cache_fn, eval_cache_fn,
                        batch_axis):
    """Returns a new model_fn, which wraps the TPU support."""

    def _model_fn(features, labels, mode, config, params):
      """An Estimator `model_fn` for TPUEstimator."""
      with self._ctx.with_mode(mode) as ctx:
        model_fn_wrapper = _ModelFnWrapper(model_fn, train_cache_fn,
                                           eval_cache_fn, config, params, ctx)

        # `input_fn` is called in `train()`, `evaluate()`, and `predict()`,
        # but not in `export_savedmodel()`.
        if self._is_input_fn_invoked:
          is_export_mode = False
        else:
          is_export_mode = True

        # Clear the bit.
        self._is_input_fn_invoked = None

        # examples_hook is added to training_hooks for both CPU and TPU
        # execution.
        examples_hook = ExamplesPerSecondHook(
            ctx.global_batch_size,
            output_dir=self.model_dir,
            every_n_steps=self._log_every_n_steps)

        if ctx.is_running_on_cpu(is_export_mode=is_export_mode):
          logging.info('Running %s on CPU', mode)
          estimator_spec = model_fn_wrapper.call_without_tpu(
              features, labels, is_export_mode=is_export_mode)
          estimator_spec = estimator_spec._replace(
              training_hooks=estimator_spec.training_hooks + (examples_hook,))
          return estimator_spec

        assert labels is None, '`labels` passed to `model_fn` must be `None`.'
        # TPUEstimator._call_input_fn passes `input_fn` as features to here.
        assert callable(features), '`input_fn` is not callable.'
        input_fn = features

        input_holders = _InputPipeline(input_fn, batch_axis, ctx)
        enqueue_ops, dequeue_fn, input_hooks, run_infeed_loop_on_coordinator = (
            input_holders.generate_infeed_enqueue_ops_and_dequeue_fn())

        graph = ops.get_default_graph()
        for enqueue_op in enqueue_ops:
          if isinstance(enqueue_op, list):
            graph.get_collection_ref(_TPU_ENQUEUE_OPS).extend(enqueue_op)
          else:
            graph.add_to_collection(_TPU_ENQUEUE_OPS, enqueue_op)

        if mode == model_fn_lib.ModeKeys.TRAIN:
          loss, host_call, scaffold, training_hooks = (
              _train_on_tpu_system(ctx, model_fn_wrapper, dequeue_fn))
          if model_fn_wrapper._params.get('track_mean', False):
            iterations_per_loop_var = _create_or_get_iterations_per_loop()
            loss = math_ops.div(
                loss,
                math_ops.cast(iterations_per_loop_var, dtype=loss.dtype))
          host_ops = host_call.create_tpu_hostcall()
          if host_ops is None:
            host_ops = []

          shutdown_hooks = []
          shutdown_mode = os.environ.get('TF_TPU_GRACEFUL_SHUTDOWN_MODE',
                                         'shutdown_worker')
          if shutdown_mode:
            if shutdown_mode == 'shutdown_worker':
              finalizer_hooks = [
                  session_support.ShutdownLameWorkers(timeout_ms=60 * 1000),
              ]
            elif shutdown_mode == 'shutdown_computation':
              finalizer_hooks = [
                  session_support.RestartComputation(timeout_ms=60 * 1000),
              ]
            else:
              raise ValueError(
                  'Unknown TF_TPU_GRACEFUL_SHUTDOWN_MODE "%s"' % shutdown_mode)

            shutdown_hooks.append(
                session_support.GracefulShutdownHook(
                    checkpoint_prefix=self.model_dir + '/model.ckpt',
                    on_shutdown_hooks=finalizer_hooks))

          with ops.control_dependencies([loss]):
            global_step = array_ops.identity(training.get_global_step())
          hooks = input_hooks + shutdown_hooks
          logging_hook_frequency = (  # Divide and round up
              (self._log_every_n_steps +
               self._config.tpu_config.iterations_per_loop - 1) //
              self._config.tpu_config.iterations_per_loop)
          iterations_per_loop = array_ops.identity(
              _create_or_get_iterations_per_loop())
          hooks.extend([
              TPUInfeedOutfeedSessionHook(
                  ctx,
                  enqueue_ops,
                  host_ops,
                  run_infeed_loop_on_coordinator=(
                      run_infeed_loop_on_coordinator),
                  rendezvous=self._rendezvous[mode],
              ),
              InstallSignalHandlerHook(),
              training.LoggingTensorHook(
                  {
                      'loss': array_ops.identity(loss),
                      'ppl': tensorflow.exp(loss),
                      'bpc': loss / tensorflow.constant(math.log(2)),
                      '#iter/loop': iterations_per_loop,
                      'global step': global_step,
                  },
                  every_n_iter=logging_hook_frequency)
          ])
          examples_hook._set_steps_per_run(  # pylint: disable=protected-access
              self._config.tpu_config.iterations_per_loop)
          hooks.append(examples_hook)

          if training_hooks:
            hooks.extend(training_hooks)

          chief_hooks = []
          if (self._config.save_checkpoints_secs or
              self._config.save_checkpoints_steps):
            checkpoint_hook = training.CheckpointSaverHook(
                self.model_dir,
                save_secs=self._config.save_checkpoints_secs,
                save_steps=self._config.save_checkpoints_steps,
                scaffold=scaffold)
            checkpoint_hook._set_steps_per_run(  # pylint: disable=protected-access
                self._config.tpu_config.iterations_per_loop)
            chief_hooks.append(checkpoint_hook)

          summary.scalar(model_fn_lib.LOSS_METRIC_KEY, loss)
          with ops.control_dependencies([loss]):
            update_ops = _sync_variables_ops()

          # Validate the TPU training graph to catch basic errors
          _validate_tpu_training_graph()

          train_op = control_flow_ops.group(*update_ops)
          graph.add_to_collection(_TPU_TRAIN_OP, train_op)

          return model_fn_lib.EstimatorSpec(
              mode,
              loss=loss,
              training_chief_hooks=chief_hooks,
              training_hooks=hooks,
              train_op=train_op,
              scaffold=scaffold)

        if mode == model_fn_lib.ModeKeys.EVAL:
          total_loss, host_calls, scaffold, eval_hooks = _eval_on_tpu_system(
              ctx, model_fn_wrapper, dequeue_fn)
          iterations_per_loop_var = _create_or_get_iterations_per_loop()
          mean_loss = math_ops.div(
              total_loss,
              math_ops.cast(iterations_per_loop_var, dtype=total_loss.dtype))

          # Creates a dummy metric update_op for all metrics. Estimator
          # expects all metrics in eval_metric_ops have update_op and calls
          # them one by one. The real metric update_ops are invoked in a
          # separated thread. So, here give Estimator the dummy op for all
          # metrics.
          with ops.control_dependencies([mean_loss]):
            # After TPU evaluation computation is done (the mean_loss tensor),
            # reads all variables back from TPU and updates the eval step
            # counter properly
            internal_ops_to_run = _sync_variables_ops()
            internal_ops_to_run.append(
                _increase_eval_step_op(iterations_per_loop_var))
            with ops.control_dependencies(internal_ops_to_run):
              dummy_update_op = control_flow_ops.no_op()

          host_call_ret = host_calls.create_tpu_hostcall()
          eval_metric_ops = {}
          eval_update_ops = []
          for k, v in host_call_ret.get('eval_metrics', {}).items():
            eval_metric_ops[k] = (v[0], dummy_update_op)
            eval_update_ops.append(v[1])

          if 'host_call' not in host_call_ret:
            host_ops = []
          else:
            host_ops = host_call_ret['host_call']
          hooks = [
              TPUInfeedOutfeedSessionHook(
                  ctx,
                  enqueue_ops,
                  eval_update_ops + host_ops,
                  run_infeed_loop_on_coordinator=(
                      run_infeed_loop_on_coordinator),
                  rendezvous=self._rendezvous[mode]),
          ] + input_hooks

          if eval_hooks:
            hooks.extend(eval_hooks)

          return model_fn_lib.EstimatorSpec(
              mode,
              loss=mean_loss,
              evaluation_hooks=hooks,
              eval_metric_ops=eval_metric_ops,
              scaffold=scaffold)

        # Predict
        assert mode == model_fn_lib.ModeKeys.PREDICT

        (dummy_predict_op, host_calls, scaffold,
         prediction_hooks) = _predict_on_tpu_system(ctx, model_fn_wrapper,
                                                    dequeue_fn)
        with ops.control_dependencies([dummy_predict_op]):
          internal_ops_to_run = _sync_variables_ops()
          with ops.control_dependencies(internal_ops_to_run):
            dummy_predict_op = control_flow_ops.no_op()

        # In train and evaluation, the main TPU program is passed to monitored
        # training session to run. Infeed enqueue and outfeed dequeue are
        # executed in side threads. This is not the configuration for
        # prediction mode.
        #
        # For prediction, the Estimator executes the EstimatorSpec.predictions
        # directly and yields the element (via generator) to call site. So,
        # the outfeed based prediction must be passed to MonitoredSession
        # directly. Other parts of the TPU execution are organized as follows.
        #
        # 1. All outfeed based Tensors must be grouped with predictions
        #    Tensors to form a single invocation. This avoids the issue that
        #    we might trigger multiple outfeeds incorrectly. To achieve this,
        #    `host_call` is placed in control_dependencies of
        #    `stopping_signals`, and `stopping_signals` is passed into
        #    _StoppingPredictHook, which sets the `stopping_signals` as
        #    SessionRunArgs. MonitoredSession merges all SessionRunArgs with
        #    the fetch in session.run together.
        #
        # 2. The TPU program (dummy_predict_op) and enqueue_ops (infeed
        #    Enqueue) are grouped together. They will be launched once and
        #    only once in side threads and they quit naturally according to
        #    the SAME stopping condition.
        enqueue_ops.append(dummy_predict_op)

        host_call_ret = host_calls.create_tpu_hostcall()
        if 'host_call' not in host_call_ret:
          host_ops = []
        else:
          host_ops = host_call_ret['host_call']

        predictions = host_call_ret['predictions']
        _verify_cross_hosts_transfer_size(
            predictions,
            message=(
                'The estimated size for TPUEstimatorSpec.predictions is too '
                'large.'))
        signals = host_call_ret['signals']

        with ops.control_dependencies(host_ops):
          host_ops = []  # Empty, we do not need it anymore.
          scalar_stopping_signal = _StopSignals.as_scalar_stopping_signal(
              signals)
          predictions = _PaddingSignals.slice_tensor_or_dict(
              predictions, signals)

        hooks = [
            _StoppingPredictHook(scalar_stopping_signal),
            TPUInfeedOutfeedSessionHookForPrediction(
                ctx, enqueue_ops, host_ops,
                rendezvous=self._rendezvous[mode]),
        ] + input_hooks

        if prediction_hooks:
          hooks.extend(prediction_hooks)

        return model_fn_lib.EstimatorSpec(
            mode,
            prediction_hooks=hooks,
            predictions=predictions,
            scaffold=scaffold)

    return _model_fn
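# Illustrative sketch (not part of the original file): how this class is
# typically wired up by the training script in this repository. The names
# `my_model_fn`, `my_train_cache_fn`, `my_eval_cache_fn`, `my_input_fn` and
# the flag values below are hypothetical placeholders, not the repository's
# actual configuration.
#
#   run_config = tpu_config.RunConfig(
#       master=master,
#       model_dir=model_dir,
#       tpu_config=tpu_config.TPUConfig(iterations_per_loop=1000))
#   estimator = TPUEstimator(
#       model_fn=my_model_fn,
#       train_cache_fn=my_train_cache_fn,   # returns the initial per-core mems
#       eval_cache_fn=my_eval_cache_fn,
#       use_tpu=True,
#       config=run_config,
#       params={'track_mean': True},
#       train_batch_size=64,
#       eval_batch_size=16)
#   estimator.train(input_fn=my_input_fn, max_steps=100000)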
def _is_tpu_tensor(tensor):
  if not isinstance(tensor, ops.Tensor):
    return False
  try:
    tensor.op.get_attr(tpu._OUTSIDE_COMPILATION_ATTR)  # pylint: disable=protected-access
  except ValueError:
    return True
  else:
    return False
def _export_output_to_tensors(export_output):
  """Get a list of `Tensors` used in `export_output`.

  Args:
    export_output: an `ExportOutput` object such as `ClassificationOutput`,
      `RegressionOutput`, or `PredictOutput`.

  Returns:
    a list of tensors used in export_output.

  Raises:
    ValueError: if `export_output` is not one of `ClassificationOutput`,
      `RegressionOutput`, or `PredictOutput`.
  """
  if isinstance(export_output, export_output_lib.ClassificationOutput):
    return [export_output.scores, export_output.classes]
  elif isinstance(export_output, export_output_lib.RegressionOutput):
    return [export_output.value]
  elif isinstance(export_output, export_output_lib.PredictOutput):
    return export_output.outputs.values()
  else:
    raise ValueError(
        '`export_output` must have type `ClassificationOutput`, '
        '`RegressionOutput`, or `PredictOutput`; got {}.'.format(
            export_output))
def _clone_export_output_with_tensors(export_output, tensors):
  """Clones `export_output` but with new `tensors`.

  Args:
    export_output: an `ExportOutput` object such as `ClassificationOutput`,
      `RegressionOutput`, or `PredictOutput`.
    tensors: a list of `Tensors` used to construct a new `export_output`.

  Returns:
    A dict similar to `export_output` but with `tensors`.

  Raises:
    ValueError: if `export_output` is not one of `ClassificationOutput`,
      `RegressionOutput`, or `PredictOutput`.
  """
  if isinstance(export_output, export_output_lib.ClassificationOutput):
    if len(tensors) != 2:
      raise ValueError('tensors must be of length 2; '
                       'got {}.'.format(len(tensors)))
    return export_output_lib.ClassificationOutput(*tensors)
  elif isinstance(export_output, export_output_lib.RegressionOutput):
    if len(tensors) != 1:
      raise ValueError('tensors must be of length 1; '
                       'got {}'.format(len(tensors)))
    return export_output_lib.RegressionOutput(*tensors)
  elif isinstance(export_output, export_output_lib.PredictOutput):
    return export_output_lib.PredictOutput(
        dict(zip(export_output.outputs.keys(), tensors)))
  else:
    raise ValueError(
        '`export_output` must have type `ClassificationOutput`, '
        '`RegressionOutput`, or `PredictOutput`; got {}.'.format(
            export_output))
def _eval_on_tpu_system(ctx, model_fn_wrapper, dequeue_fn):
  """Executes `model_fn_wrapper` multiple times on all TPU shards."""
  iterations_per_loop_var = _create_or_get_iterations_per_loop()

  (single_tpu_eval_step, host_calls, captured_scaffold_fn,
   captured_eval_hooks) = model_fn_wrapper.convert_to_single_tpu_eval_step(
       dequeue_fn)

  def multi_tpu_eval_steps_on_single_shard():
    loop_vars = [_ZERO_LOSS]
    if model_fn_wrapper._eval_cache_fn is not None:
      batch_size = ctx.global_batch_size
      num_shards = ctx._config._tpu_config.num_shards
      loop_vars += model_fn_wrapper._eval_cache_fn(batch_size // num_shards)

    return training_loop.repeat(iterations_per_loop_var, single_tpu_eval_step,
                                loop_vars)

  ret = tpu.shard(
      multi_tpu_eval_steps_on_single_shard,
      inputs=[],
      num_shards=ctx.num_replicas,
      outputs_from_all_shards=False,
      device_assignment=ctx.device_assignment)
  loss = ret[0]

  scaffold = _get_scaffold(captured_scaffold_fn)
  return loss, host_calls, scaffold, captured_eval_hooks.get()
def _train_on_tpu_system(ctx, model_fn_wrapper, dequeue_fn):
  """Executes `model_fn_wrapper` multiple times on all TPU shards."""
  iterations_per_loop_var = _create_or_get_iterations_per_loop()

  (single_tpu_train_step, host_call, captured_scaffold_fn,
   captured_training_hooks) = (
       model_fn_wrapper.convert_to_single_tpu_train_step(dequeue_fn))

  def multi_tpu_train_steps_on_single_shard():
    if model_fn_wrapper._params.get('track_mean', False):
      loop_vars = [_ZERO_LOSS]
    else:
      loop_vars = [_INITIAL_LOSS]
    if model_fn_wrapper._train_cache_fn is not None:
      batch_size = ctx.global_batch_size
      num_shards = ctx._config._tpu_config.num_shards
      loop_vars += model_fn_wrapper._train_cache_fn(batch_size // num_shards)

    return training_loop.repeat(
        iterations_per_loop_var, single_tpu_train_step, loop_vars)

  ret = tpu.shard(
      multi_tpu_train_steps_on_single_shard,
      inputs=[],
      num_shards=ctx.num_replicas,
      outputs_from_all_shards=False,
      device_assignment=ctx.device_assignment)
  loss = ret[0]

  scaffold = _get_scaffold(captured_scaffold_fn)
  return loss, host_call, scaffold, captured_training_hooks.get()
def _predict_on_tpu_system(ctx, model_fn_wrapper, dequeue_fn):
  """Executes `model_fn_wrapper` multiple times on all TPU shards."""
  (single_tpu_predict_step, host_calls, captured_scaffold_fn,
   captured_predict_hooks
  ) = model_fn_wrapper.convert_to_single_tpu_predict_step(dequeue_fn)

  def multi_tpu_predict_steps_on_single_shard():

    def cond(scalar_stopping_signal):
      return math_ops.logical_not(
          _StopSignals.should_stop(scalar_stopping_signal))

    inputs = [_StopSignals.NON_STOPPING_SIGNAL]
    outputs = training_loop.while_loop(
        cond, single_tpu_predict_step, inputs=inputs, name=b'loop')
    return outputs

  (dummy_predict_op,) = tpu.shard(
      multi_tpu_predict_steps_on_single_shard,
      inputs=[],
      num_shards=ctx.num_replicas,
      outputs_from_all_shards=False,
      device_assignment=ctx.device_assignment)

  scaffold = _get_scaffold(captured_scaffold_fn)
  return dummy_predict_op, host_calls, scaffold, captured_predict_hooks.get()
def _wrap_computation_in_while_loop(device, op_fn):
  """Wraps the ops generated by `op_fn` in tf.while_loop."""

  def computation(i):
    with ops.control_dependencies(op_fn()):
      return i + 1

  iterations_per_loop_var = _create_or_get_iterations_per_loop()
  # By setting parallel_iterations=1, the parallel execution in while_loop is
  # basically turned off.
  with ops.device(device):
    iterations = array_ops.identity(iterations_per_loop_var)
    return control_flow_ops.while_loop(
        lambda i: i < iterations,
        computation, [constant_op.constant(0)],
        parallel_iterations=1)


def _wrap_computation_in_while_loop_with_stopping_signals(device, op_fn):
  """Wraps the ops generated by `op_fn` in tf.while_loop."""

  def cond(scalar_stopping_signal):
    return math_ops.logical_not(
        _StopSignals.should_stop(scalar_stopping_signal))

  def computation(unused_scalar_stopping_signal):
    return_value = op_fn()
    execute_ops = return_value['ops']
    signals = return_value['signals']
    with ops.control_dependencies(execute_ops):
      return _StopSignals.as_scalar_stopping_signal(signals)

  # By setting parallel_iterations=1, the parallel execution in while_loop is
  # basically turned off.
  with ops.device(device):
    return control_flow_ops.while_loop(
        cond,
        computation, [_StopSignals.NON_STOPPING_SIGNAL],
        parallel_iterations=1)
def _validate_tpu_training_graph():
  """Validate graph before running distributed training.

  Raises:
    ValueError: If the graph seems invalid for running on device
  """
  operations = ops.get_default_graph().get_operations()

  # Check if there is at least one CrossReplicaSum operation in the graph.
  # This should be introduced by using the CrossShardOptimizer wrapper.
  cross_replica_sum_ops = [
      o for o in operations if o.type == _CROSS_REPLICA_SUM_OP
  ]
  if not cross_replica_sum_ops:
    raise ValueError(
        'CrossShardOptimizer must be used for model training on TPUs.')
class _CapturedObject(object):
  """A placeholder to capture an object.

  This is useful when we need to capture a Python object in the Tensorflow
  control flow body function and use it outside the control flow.
  """

  def __init__(self):
    self._object = None
    self._captured = False

  def capture(self, o):
    if self._captured:
      raise RuntimeError(
          'InternalError: Object can capture only once. Please file bug.')

    self._captured = True
    self._object = o

  def get(self):
    if not self._captured:
      raise RuntimeError(
          'InternalError: Object is not captured properly before `get`. '
          'Please file bug.')

    return self._object


def _get_scaffold(captured_scaffold_fn):
  """Retrieves the Scaffold from `captured_scaffold_fn`."""
  with _CapturingContext(message='Inside scaffold_fn'):
    scaffold_fn = captured_scaffold_fn.get()
    if scaffold_fn:
      scaffold = scaffold_fn()
      if scaffold is None:
        raise ValueError(
            'TPUEstimatorSpec.scaffold_fn returns None, which is not allowed')
    else:
      scaffold = None

  if scaffold:
    wrapped_finalize = scaffold.finalize

    def _finalize():
      with _CapturingContext('Inside Scaffold.finalize'):
        wrapped_finalize()

    scaffold.finalize = _finalize
  return scaffold
class _CapturingContext(control_flow_ops.ControlFlowContext):
  """Tracks references to Tensors defined in TPU replication."""

  def __init__(self, message):
    control_flow_ops.ControlFlowContext.__init__(self)
    self._message = message

  def AddOp(self, op):  # pylint: disable=invalid-name
    for c in op.inputs:
      if tpu._TPU_REPLICATE_ATTR in c.op.node_def.attr:  # pylint: disable=protected-access
        raise ValueError('{}: Op {} depends on TPU computation {}, '
                         'which is not allowed.'.format(self._message, op, c))

  def to_control_flow_context_def(self, context_def, export_scope=None):
    # pylint: disable=useless-super-delegation
    # NOTE(slebedev): the method is required by `ControlFlowContext`.
    super(_CapturingContext, self).to_control_flow_context_def(
        context_def, export_scope)

  def __enter__(self):
    # pylint: disable=protected-access
    self._g = ops.get_default_graph()
    self._old = self._g._get_control_flow_context()
    self._g._set_control_flow_context(self)
    # pylint: enable=protected-access

  def __exit__(self, _, __, ___):  # pylint: disable=invalid-name
    self._g._set_control_flow_context(self._old)  # pylint: disable=protected-access
class _Inputs(object):
  """A data structure representing the input_fn returned values.

  This also supports the returned value from input_fn as `Dataset`.
  """

  def __init__(self, features=None, labels=None, dataset=None, signals=None):
    if dataset is not None and (features is not None or labels is not None or
                                signals is not None):
      raise RuntimeError('Internal Error: Either (features and labels) or '
                         'dataset should be provided, not both. Please file '
                         'bug')

    self._features = features
    self._labels = labels
    self._signals = signals

    self._dataset = dataset
    self._iterator = None

  @staticmethod
  def from_input_fn(return_values):
    """Returns an `_Inputs` instance according to `input_fn` return value."""
    if isinstance(return_values, dataset_ops.Dataset):
      dataset = return_values
      return _Inputs(dataset=dataset)

    features, labels = _Inputs._parse_inputs(return_values)
    return _Inputs(features, labels)

  @staticmethod
  def _parse_inputs(return_values):
    if isinstance(return_values, tuple):
      features, labels = return_values
    else:
      features, labels = return_values, None
    return features, labels

  @property
  def is_dataset(self):
    """Returns True if the return value from input_fn is Dataset."""
    return self._dataset is not None

  def dataset_initializer_hook(self):
    """Returns a `SessionRunHook` to initialize this dataset.

    This must be called before `features_and_labels`.
    """
    iterator = self._dataset.make_initializable_iterator()
    # pylint: disable=protected-access
    hook = estimator_util._DatasetInitializerHook(iterator)
    # pylint: enable=protected-access
    self._iterator = iterator
    return hook

  def features_and_labels(self):
    """Gets `features` and `labels`."""
    if self.is_dataset:
      if self._iterator is None:
        raise RuntimeError('Internal error: Must call '
                           'dataset_initializer_hook before calling '
                           'features_and_labels(). Please file a bug!')
      return _Inputs._parse_inputs(self._iterator.get_next())

    return (self._features, self._labels)

  def signals(self):
    return self._signals

  @property
  def dataset(self):
    return self._dataset
class _InputsWithStoppingSignals(_Inputs):
  """Inputs with `_StopSignals` inserted into the dataset."""

  def __init__(self, dataset, batch_size, add_padding=False,
               num_invocations_per_step=1):

    assert dataset is not None
    user_provided_dataset = dataset.map(
        _InputsWithStoppingSignals.insert_stopping_signal(
            stop=False, batch_size=batch_size, add_padding=add_padding))
    if num_invocations_per_step == 1:
      final_batch_dataset = dataset.take(1).map(
          _InputsWithStoppingSignals.insert_stopping_signal(
              stop=True, batch_size=batch_size, add_padding=add_padding))
    else:
      # We append (2 * num_invocations_per_step - 1) batches for exhausting the
      # user_provided_dataset and stop properly.
      # For example, if num_invocations_per_step is 2, we append 3 additional
      # padding batches: b1, b2, b3.
      # If user_provided_dataset contains two batches: a1, a2
      # Step 1: [a1, a2]
      # Step 2: [b1, b2] -> STOP
      # If user_provided_dataset contains three batches: a1, a2, a3.
      # The training loops:
      # Step 1: [a1, a2]
      # Step 2: [a3, b1]
      # Step 3: [b2, b3] -> STOP.
      final_batch_dataset = dataset.take(1).map(
          _InputsWithStoppingSignals.insert_stopping_signal(
              stop=True, batch_size=batch_size, add_padding=add_padding))
      final_batch_dataset = final_batch_dataset.repeat(
          2 * num_invocations_per_step - 1)

      def _set_mask(data_dict):
        signals = data_dict['signals']
        signals['padding_mask'] = array_ops.ones_like(signals['padding_mask'])
        data_dict['signals'] = signals
        return data_dict

      # Mask out the extra batch.
      final_batch_dataset = final_batch_dataset.map(_set_mask)

    dataset = user_provided_dataset.concatenate(final_batch_dataset).prefetch(2)

    super(_InputsWithStoppingSignals, self).__init__(dataset=dataset)
    self._current_inputs = None

  def features_and_labels(self):
    if self._current_inputs is not None:
      raise RuntimeError(
          'Internal Error: The previous inputs have not been properly '
          'consumed. First call features_and_labels, then call signals.')

    inputs_with_signals = self._iterator.get_next()
    features = inputs_with_signals['features']
    labels = inputs_with_signals.get('labels')

    self._current_inputs = inputs_with_signals
    return features, labels

  def signals(self):
    """Returns the `Signals` from `_Inputs`."""
    if self._current_inputs is None:
      raise RuntimeError(
          'Internal Error: The current inputs have not been properly '
          'generated. First call features_and_labels, then call signals.')
    signals = self._current_inputs['signals']
    self._current_inputs = None
    return signals
  @staticmethod
  def insert_stopping_signal(stop, batch_size, add_padding=False):
    """Inserts stopping_signal into dataset via _map_fn.

    Here we change the data structure in the dataset, such that the return
    value is a dictionary now and `features`, `labels`, and `signals` are three
    distinguished keys in that dict. This provides a better structure, which
    eases the process to decompose the inputs (see `features_and_labels`).

    Args:
      stop: bool, state of current stopping signals.
      batch_size: int, batch size.
      add_padding: bool, whether to pad the tensor to full batch size.

    Returns:
      A map_fn passed to dataset.map API.
    """

    def _map_fn(*args):
      """The map fn to insert signals."""
      if len(args) == 1:
        # Unpack the single Tensor/dict argument as features. This is required
        # for the case where the input_fn returns no labels.
        args = args[0]
      features, labels = _Inputs._parse_inputs(args)
      new_input_dict = {}

      if add_padding:
        padding_mask, features, labels = (
            _PaddingSignals.pad_features_and_labels(features, labels,
                                                    batch_size))

        new_input_dict['features'] = features
        if labels is not None:
          new_input_dict['labels'] = labels

      else:
        new_input_dict['features'] = features
        if labels is not None:
          new_input_dict['labels'] = labels
        padding_mask = None

      new_input_dict['signals'] = _StopSignals(
          stop=stop, batch_size=batch_size,
          padding_mask=padding_mask).as_dict()

      return new_input_dict

    return _map_fn
class _StopSignals(object):
  """Signals class holding all logic to handle TPU stopping condition."""

  NON_STOPPING_SIGNAL = False
  STOPPING_SIGNAL = True

  def __init__(self, stop, batch_size, padding_mask=None):
    self._stop = stop
    self._batch_size = batch_size
    self._padding_mask = padding_mask

  def as_dict(self):
    """Returns the signals as Python dict."""
    shape = [self._batch_size, 1]
    dtype = dtypes.bool

    if self._stop:
      stopping = array_ops.ones(shape=shape, dtype=dtype)
    else:
      stopping = array_ops.zeros(shape=shape, dtype=dtype)

    signals = {'stopping': stopping}
    if self._padding_mask is not None:
      signals['padding_mask'] = self._padding_mask
    return signals

  @staticmethod
  def as_scalar_stopping_signal(signals):
    return array_ops.identity(signals['stopping'][0][0])

  @staticmethod
  def should_stop(scalar_stopping_signal):
    """Detects whether scalar_stopping_signal indicates stopping."""
    if isinstance(scalar_stopping_signal, ops.Tensor):
      # STOPPING_SIGNAL is a constant True. Here, the logical_and is just the
      # TF way to express the bool check whether scalar_stopping_signal is
      # True.
      return math_ops.logical_and(scalar_stopping_signal,
                                  _StopSignals.STOPPING_SIGNAL)
    else:
      # For non Tensor case, it is used in SessionRunHook. So, we cannot modify
      # the graph anymore. Here, we use pure Python.
      return bool(scalar_stopping_signal)
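
# A minimal sketch (not part of the original file) of how a stopping signal is
# built for one batch and then collapsed to the scalar that the predict loop
# tests; the batch size of 4 is illustrative.
def _example_stop_signal_usage():
  signals = _StopSignals(stop=True, batch_size=4).as_dict()
  scalar = _StopSignals.as_scalar_stopping_signal(signals)
  return _StopSignals.should_stop(scalar)  # a bool Tensor evaluating to True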
class _PaddingSignals(object):
  """Signals class holding all logic to handle padding."""

  @staticmethod
  def pad_features_and_labels(features, labels, batch_size):
    """Pads out the batch dimension of features and labels."""
    real_batch_size = array_ops.shape(
        _PaddingSignals._find_any_tensor(features))[0]

    batch_size_tensor = constant_op.constant(batch_size, dtypes.int32)

    check_greater = check_ops.assert_greater_equal(
        batch_size_tensor, real_batch_size,
        data=(batch_size_tensor, real_batch_size),
        message='The real batch size should not be greater than batch_size.')

    with ops.control_dependencies([check_greater]):
      missing_count = batch_size_tensor - real_batch_size

    def pad_single_tensor(tensor):
      """Pads out the batch dimension of a tensor to the complete batch_size."""
      rank = len(tensor.shape)
      assert rank > 0
      padding = array_ops.stack([[0, missing_count]] + [[0, 0]] * (rank - 1))
      padded_shape = (batch_size,) + tuple(tensor.shape[1:])
      padded_tensor = array_ops.pad(tensor, padding)
      padded_tensor.set_shape(padded_shape)
      return padded_tensor

    def nest_pad(tensor_or_dict):
      return nest.map_structure(pad_single_tensor, tensor_or_dict)

    features = nest_pad(features)
    if labels is not None:
      labels = nest_pad(labels)

    padding_mask = _PaddingSignals._padding_mask(real_batch_size,
                                                 missing_count, batch_size)

    return padding_mask, features, labels
  @staticmethod
  def slice_tensor_or_dict(tensor_or_dict, signals):
    """Slice the real Tensors according to padding mask in signals."""

    padding_mask = signals['padding_mask']
    batch_size = array_ops.shape(padding_mask)[0]

    def verify_batch_size(tensor):
      check_batch_size = math_ops.equal(batch_size, tensor.shape[0])
      with ops.control_dependencies([check_batch_size]):
        return array_ops.identity(tensor)

    def slice_single_tensor(tensor):
      rank = len(tensor.shape)
      assert rank > 0
      real_batch_size = batch_size - math_ops.reduce_sum(padding_mask)
      return verify_batch_size(tensor)[0:real_batch_size]

    # As we split the Tensors to all TPU cores and concat them back, it is
    # important to ensure the real data is placed before padded ones, i.e.,
    # order is preserved. By that, the sliced padding mask should have all 0's.
    # If this assertion failed, the slice logic here would not hold.
    sliced_padding_mask = slice_single_tensor(padding_mask)
    assert_padding_mask = math_ops.equal(
        math_ops.reduce_sum(sliced_padding_mask), 0)

    with ops.control_dependencies([assert_padding_mask]):
      should_stop = _StopSignals.should_stop(
          _StopSignals.as_scalar_stopping_signal(signals))

    is_full_batch = math_ops.equal(math_ops.reduce_sum(padding_mask), 0)

    def slice_fn(tensor):
      # If the current batch is full batch or part of stopping signals, we do
      # not need to slice to save performance.
      return control_flow_ops.cond(
          math_ops.logical_or(should_stop, is_full_batch),
          (lambda: verify_batch_size(tensor)),
          (lambda: slice_single_tensor(tensor)))

    return nest.map_structure(slice_fn, tensor_or_dict)
  @staticmethod
  def _find_any_tensor(batch_features):
    tensors = [
        x for x in nest.flatten(batch_features) if isinstance(x, ops.Tensor)
    ]
    if not tensors:
      raise ValueError('Cannot find any Tensor in features dict.')
    return tensors[0]

  @staticmethod
  def _padding_mask(real_batch_size, missing_count, batch_size):
    padding_mask = array_ops.concat([
        array_ops.zeros((real_batch_size,), dtype=dtypes.int32),
        array_ops.ones((missing_count,), dtype=dtypes.int32)
    ], axis=0)
    padding_mask.set_shape((batch_size,))
    return padding_mask
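
# A minimal sketch (not part of the original file) of the pad/slice round trip
# implemented by `_PaddingSignals`: a batch of 3 examples is padded up to 4,
# and the padding mask lets `slice_tensor_or_dict` recover the original rows.
# Shapes and values are illustrative only.
def _example_pad_and_slice_roundtrip():
  features = {'x': array_ops.ones([3, 5])}
  padding_mask, padded, _ = _PaddingSignals.pad_features_and_labels(
      features, None, batch_size=4)
  signals = {'padding_mask': padding_mask,
             'stopping': array_ops.zeros([4, 1], dtype=dtypes.bool)}
  return _PaddingSignals.slice_tensor_or_dict(padded, signals)  # {'x': 3 rows}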
def _verify_cross_hosts_transfer_size(tensor_dict, message):
  total_size = 0
  tensor_structure = {}
  for key, tensor in tensor_dict.items():
    shape = tensor.shape
    size = np.product(shape) * tensor.dtype.size
    tensor_structure[key] = shape
    total_size += size
  if total_size >= _ONE_GIGABYTE:
    raise ValueError(
        '{} The transfer size is larger than the protobuf limit. Please '
        'consider to use Tensors with smaller shapes or reduce batch '
        'size. Given:\n'
        '{}'.format(
            message, '\n'.join([
                ' -- Key: {}, Shape: {}'.format(k, v)
                for k, v in tensor_structure.items()
            ])))
def _add_item_to_params(params, key, value):
  """Adds a new item into `params`."""
  if isinstance(params, hparam.HParams):
    # For HParams, we need to use special API.
    if key in params:
      params.set_hparam(key, value)
    else:
      params.add_hparam(key, value)
  else:
    # Now params is Python dict.
    params[key] = value
def export_estimator_savedmodel(estimator,
                                export_dir_base,
                                serving_input_receiver_fn,
                                assets_extra=None,
                                as_text=False,
                                checkpoint_path=None,
                                strip_default_attrs=False):
  """Export `Estimator` trained model for TPU inference.

  Args:
    estimator: `Estimator` with which model has been trained.
    export_dir_base: A string containing a directory in which to create
      timestamped subdirectories containing exported SavedModels.
    serving_input_receiver_fn: A function that takes no argument and
      returns a `ServingInputReceiver` or `TensorServingInputReceiver`.
    assets_extra: A dict specifying how to populate the assets.extra directory
      within the exported SavedModel, or `None` if no extra assets are needed.
    as_text: whether to write the SavedModel proto in text format.
    checkpoint_path: The checkpoint path to export. If `None` (the default),
      the most recent checkpoint found within the model directory is chosen.
    strip_default_attrs: Boolean. If `True`, default-valued attributes will be
      removed from the NodeDefs.

  Returns:
    The string path to the exported directory.
  """
  # `TPUEstimator` requires `tpu_config.RunConfig`, so we cannot use
  # `estimator.config`.
  config = tpu_config.RunConfig(model_dir=estimator.model_dir)
  est = TPUEstimator(
      estimator._model_fn,  # pylint: disable=protected-access
      config=config,
      params=estimator.params,
      use_tpu=True,
      train_batch_size=2048,  # Does not matter.
      eval_batch_size=2048,  # Does not matter.
  )
  return est.export_savedmodel(export_dir_base, serving_input_receiver_fn,
                               assets_extra, as_text, checkpoint_path,
                               strip_default_attrs)
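
# A hedged usage sketch (not part of the original file), assuming a trained
# `Estimator` and a `serving_input_receiver_fn` already exist; the export
# directory path is made up for illustration.
def _example_export(estimator, serving_input_receiver_fn):
  return export_estimator_savedmodel(
      estimator, '/tmp/tpu_export', serving_input_receiver_fn)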
TensorFlow/NLP/transformer-xl-master/train.py 0 → 100644
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import math
import time

from absl import flags
import absl.logging as _logging  # pylint: disable=unused-import

from six.moves import xrange  # pylint: disable=redefined-builtin
import tensorflow as tf
from tensorflow.gfile import Exists as exists
import model
import data_utils
import tpu_estimator

import numpy as np
from time import sleep

# TPU parameters
flags.DEFINE_string("master", default=None, help="master")
flags.DEFINE_string("tpu", default=None,
      help="The Cloud TPU to use for training. This should be either the name "
      "used when creating the Cloud TPU, or a grpc://ip.address.of.tpu:8470 url.")
flags.DEFINE_string("gcp_project", default=None,
      help="Project name for the Cloud TPU-enabled project. If not specified, "
      "we will attempt to automatically detect the GCE project from metadata.")
flags.DEFINE_string("tpu_zone", default=None,
      help="GCE zone where the Cloud TPU is located in. If not specified, we "
      "will attempt to automatically detect the GCE project from metadata.")
flags.DEFINE_bool("use_tpu", default=True,
      help="Use TPUs rather than plain CPUs.")
flags.DEFINE_integer("num_hosts", default=1, help="number of TPU hosts")
flags.DEFINE_integer("num_core_per_host", default=8, help="number of cores per host")

# Experiment (data/checkpoint/directory) parameters
flags.DEFINE_string("data_dir", default="", help="Path to tf-records directory.")
flags.DEFINE_string("record_info_dir", default="",
      help="Path to local directory containing filenames.txt.")
flags.DEFINE_string("corpus_info_path", default="", help="Path to corpus-info.json file.")
flags.DEFINE_string("model_dir", default=None, help="Estimator model_dir.")
flags.DEFINE_bool("do_eval", default=False, help="Whether to run eval on the dev set.")
flags.DEFINE_bool("track_mean", default=True, help="Trace mean loss during training.")
flags.DEFINE_string("eval_ckpt_path", None,
      help="Checkpoint path for evaluation. "
      "If set, model_dir will be ignored. "
      "If unset, will use the latest ckpt in model_dir.")
flags.DEFINE_string("warm_start_path", None,
      help="Checkpoint path for warm start. "
      "If set, will clear Adam states. "
      "Note that the new model_dir should be different from warm_start_path.")

# Optimization parameters
flags.DEFINE_float("learning_rate", default=2.5e-4, help="Maximum learning rate.")
flags.DEFINE_float("clip", default=0.25, help="Gradient clipping value.")
# for cosine decay
flags.DEFINE_float("min_lr_ratio", default=0.01, help="Minimum ratio learning rate.")
flags.DEFINE_integer("warmup_steps", default=0, help="Number of steps for linear lr warmup.")

# Training parameters
flags.DEFINE_integer("train_batch_size", default=60, help="Size of train batch.")
flags.DEFINE_integer("eval_batch_size", default=60, help="Size of valid batch.")
flags.DEFINE_integer("train_steps", default=100000, help="Total number of training steps.")
flags.DEFINE_integer("iterations", default=500, help="Number of iterations per repeat loop.")
flags.DEFINE_integer("save_steps", default=10000, help="Number of steps for model checkpointing.")

# Evaluation parameters
flags.DEFINE_integer("max_eval_batch", default=-1,
      help="Set -1 to turn off. Only used in test mode.")
flags.DEFINE_bool("do_eval_only", default=False, help="Run evaluation only.")
flags.DEFINE_integer("start_eval_steps", default=10000,
      help="Which checkpoint to start with in `do_eval_only` mode.")
flags.DEFINE_string("eval_split", "valid", help="Which data split to evaluate.")

# Model parameters
flags.DEFINE_integer("tgt_len", default=70, help="Number of steps to predict")
flags.DEFINE_integer("mem_len", default=70, help="Number of steps to cache")
flags.DEFINE_bool("same_length", default=False, help="Same length attention")
flags.DEFINE_integer("clamp_len", default=-1, help="Clamp length")
flags.DEFINE_integer("n_layer", default=6, help="Number of layers.")
flags.DEFINE_integer("d_model", default=500, help="Dimension of the model.")
flags.DEFINE_integer("d_embed", default=500, help="Dimension of the embeddings.")
flags.DEFINE_integer("n_head", default=10, help="Number of attention heads.")
flags.DEFINE_integer("d_head", default=50, help="Dimension of each attention head.")
flags.DEFINE_integer("d_inner", default=1000,
      help="Dimension of inner hidden size in positionwise feed-forward.")
flags.DEFINE_float("dropout", default=0.1, help="Dropout rate.")
flags.DEFINE_float("dropatt", default=0.1, help="Attention dropout rate.")
flags.DEFINE_bool("untie_r", default=False, help="untie r_w_bias and r_r_bias")

# Adaptive Softmax / Embedding
flags.DEFINE_bool("tie_weight", default=True, help="Tie embedding and softmax weight.")
flags.DEFINE_integer("div_val", default=1,
      help="Divide the embedding size by this val for each bin")
flags.DEFINE_bool("proj_share_all_but_first", default=False,
      help="True to share all but first projs, False not to share.")
flags.DEFINE_bool("proj_same_dim", default=True,
      help="Project the bin with the same dimension.")

# Parameter initialization
flags.DEFINE_enum("init", default="normal", enum_values=["normal", "uniform"],
      help="Initialization method.")
flags.DEFINE_float("init_std", default=0.02, help="Initialization std when init is normal.")
flags.DEFINE_float("proj_init_std", default=0.01, help="Initialization std for embedding projection.")
flags.DEFINE_float("init_range", default=0.1, help="Initialization std when init is uniform.")

FLAGS = flags.FLAGS
def metric_fn(loss):
  """Evaluation metric Fn which runs on CPU."""
  perplexity = tf.exp(tf.reduce_mean(loss))
  bpc = tf.reduce_mean(loss) / tf.constant(math.log(2))
  return {
      "perplexity": tf.metrics.mean(perplexity),
      "bpc": tf.metrics.mean(bpc),
  }
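
# Quick sanity check (not part of the original file) of the relationship the
# metrics above encode: a mean log-loss of ln(2) nats corresponds to a
# perplexity of 2 and exactly 1 bit per character.
def _check_metric_relationship():
  loss = math.log(2)
  assert abs(math.exp(loss) - 2.0) < 1e-6      # perplexity
  assert abs(loss / math.log(2) - 1.0) < 1e-6  # bits per character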
def get_model_fn(n_token, cutoffs, train_bin_sizes, eval_bin_sizes):
  def model_fn(features, labels, mode, params):
    is_training = (mode == tf.estimator.ModeKeys.TRAIN)

    batch_size = params["batch_size"]

    mems = params["cache"]
    inp = tf.transpose(features["inputs"], [1, 0])
    tgt = tf.transpose(features["labels"], [1, 0])

    bin_sizes = train_bin_sizes if is_training else eval_bin_sizes
    if bin_sizes:
      inp_perms = [tf.transpose(features["inp_mask"], [1, 0])]
      tgt_perms = [tf.transpose(features["tgt_mask"], [1, 0])]

      head_tgt = tf.transpose(features["head_labels"], [1, 0])

      for b in range(len(bin_sizes)):
        inp_perm = tf.transpose(features["inp_perm_{}".format(b)], [1, 0, 2])
        tgt_perm = tf.transpose(features["tgt_perm_{}".format(b)], [1, 0, 2])

        inp_perms.append(inp_perm)
        tgt_perms.append(tgt_perm)
    else:
      inp_perms, tgt_perms, head_tgt = None, None, None

    if FLAGS.init == "uniform":
      initializer = tf.initializers.random_uniform(
          minval=-FLAGS.init_range,
          maxval=FLAGS.init_range,
          seed=None)
    elif FLAGS.init == "normal":
      initializer = tf.initializers.random_normal(
          stddev=FLAGS.init_std,
          seed=None)
      proj_initializer = tf.initializers.random_normal(
          stddev=FLAGS.proj_init_std,
          seed=None)

    tie_projs = [False for _ in range(len(cutoffs) + 1)]
    if FLAGS.proj_share_all_but_first:
      for i in range(1, len(tie_projs)):
        tie_projs[i] = True

    tf.logging.info("Vocab size : {}".format(n_token))
    tf.logging.info("Batch size : {}".format(batch_size))

    loss, new_mems = model.transformer(
        dec_inp=inp,
        target=tgt,
        mems=mems,
        n_token=n_token,
        n_layer=FLAGS.n_layer,
        d_model=FLAGS.d_model,
        d_embed=FLAGS.d_embed,
        n_head=FLAGS.n_head,
        d_head=FLAGS.d_head,
        d_inner=FLAGS.d_inner,
        dropout=FLAGS.dropout,
        dropatt=FLAGS.dropatt,
        initializer=initializer,
        is_training=is_training,
        mem_len=FLAGS.mem_len,
        cutoffs=cutoffs,
        div_val=FLAGS.div_val,
        tie_projs=tie_projs,
        input_perms=inp_perms,
        target_perms=tgt_perms,
        head_target=head_tgt,
        same_length=FLAGS.same_length,
        clamp_len=FLAGS.clamp_len,
        use_tpu=FLAGS.use_tpu,
        untie_r=FLAGS.untie_r,
        proj_same_dim=FLAGS.proj_same_dim)

    total_loss = tf.reduce_mean(loss)

    if mode == tf.estimator.ModeKeys.EVAL:
      if FLAGS.use_tpu:
        with tf.colocate_with(total_loss):
          total_loss = tf.contrib.tpu.cross_replica_sum(total_loss) \
              / FLAGS.num_hosts / FLAGS.num_core_per_host
      metric_loss = tf.tile(tf.reshape(total_loss, [1, 1]), [batch_size, 1])
      eval_spec = tf.contrib.tpu.TPUEstimatorSpec(
          mode=mode,
          loss=total_loss,
          eval_metrics=(metric_fn, [metric_loss]))

      eval_spec.cache = new_mems

      return eval_spec

    # Configuring the optimization step.
    global_step = tf.train.get_global_step()

    # increase the learning rate linearly
    if FLAGS.warmup_steps > 0:
      warmup_lr = tf.to_float(global_step) / tf.to_float(FLAGS.warmup_steps) \
          * FLAGS.learning_rate
    else:
      warmup_lr = 0.0

    # number of parameters
    num_params = np.sum([np.prod(v.shape) for v in tf.trainable_variables()])
    tf.logging.info("#params: {}".format(num_params))

    # format_str = '{{:<{0}s}}\t{{}}'.format(
    #     max([len(v.name) for v in tf.trainable_variables()]))
    # for v in tf.trainable_variables():
    #   tf.logging.info(format_str.format(v.name, v.get_shape()))

    # decay the learning rate using the cosine schedule
    decay_lr = tf.train.cosine_decay(
        FLAGS.learning_rate,
        global_step=global_step - FLAGS.warmup_steps,
        decay_steps=FLAGS.train_steps - FLAGS.warmup_steps,
        alpha=FLAGS.min_lr_ratio)

    learning_rate = tf.where(global_step < FLAGS.warmup_steps,
                             warmup_lr, decay_lr)

    if FLAGS.use_tpu:
      optimizer = tf.contrib.tpu.CrossShardOptimizer(
          tf.train.AdamOptimizer(learning_rate=learning_rate))
      # GradientDescentOptimizer
    else:
      optimizer = tf.train.AdamOptimizer(learning_rate=learning_rate)

    grads_and_vars = optimizer.compute_gradients(total_loss)
    gradients, variables = zip(*grads_and_vars)
    clipped, _ = tf.clip_by_global_norm(gradients, FLAGS.clip)
    train_op = optimizer.apply_gradients(
        zip(clipped, variables), global_step=tf.train.get_global_step())

    # Constructing TPUEstimatorSpec with cache.
    train_spec = tf.contrib.tpu.TPUEstimatorSpec(
        mode=mode, loss=total_loss, train_op=train_op)

    if FLAGS.mem_len < FLAGS.tgt_len:
      # Truncate each cached memory tensor to mem_len steps.
      new_mems = [mem_t[:FLAGS.mem_len] for mem_t in new_mems]
    train_spec.cache = new_mems

    return train_spec

  return model_fn
def get_cache_fn(mem_len):
  def cache_fn(batch_size):
    mems = []
    for l in xrange(FLAGS.n_layer):
      if mem_len > 0:
        mems.append(
            tf.zeros([mem_len, batch_size, FLAGS.d_model], dtype=tf.float32))
      else:
        mems.append(tf.zeros([mem_len], dtype=tf.float32))
    return mems

  return cache_fn
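
# Shape sketch (not part of the original file): with the default flags
# (n_layer=6, mem_len=70, d_model=500) and an assumed per-core batch of 8,
# the cache built above is a list of 6 zero tensors, each [70, 8, 500].
def _example_cache_shapes(batch_size=8):
  cache = get_cache_fn(FLAGS.mem_len)(batch_size)
  return [m.shape.as_list() for m in cache]  # e.g. [[70, 8, 500]] * 6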
def main(unused_argv):
  del unused_argv  # Unused

  tf.logging.set_verbosity(tf.logging.INFO)

  # Get corpus info
  corpus_info = data_utils.get_corpus_info(FLAGS.corpus_info_path)
  n_token = corpus_info["vocab_size"]
  cutoffs = corpus_info["cutoffs"][1:-1]

  if FLAGS.save_steps == 0:
    FLAGS.save_steps = None

  if not FLAGS.do_eval_only:
    # Get train input function
    train_input_fn, train_record_info = data_utils.get_input_fn(
        record_info_dir=FLAGS.record_info_dir,
        split="train",
        per_host_bsz=FLAGS.train_batch_size // FLAGS.num_hosts,
        tgt_len=FLAGS.tgt_len,
        num_core_per_host=FLAGS.num_core_per_host,
        num_hosts=FLAGS.num_hosts,
        use_tpu=FLAGS.use_tpu)
    train_bin_sizes = train_record_info["bin_sizes"]
    num_train_batch = train_record_info["num_batch"]

    # Get train cache function
    train_cache_fn = get_cache_fn(FLAGS.mem_len)
  else:
    train_bin_sizes = []
    num_train_batch = None
    train_cache_fn = None

  if FLAGS.do_eval or FLAGS.do_eval_only:
    assert FLAGS.num_hosts == 1
    # Get eval input function
    eval_input_fn, eval_record_info = data_utils.get_input_fn(
        record_info_dir=FLAGS.record_info_dir,
        split=FLAGS.eval_split,
        per_host_bsz=FLAGS.eval_batch_size // FLAGS.num_hosts,
        tgt_len=FLAGS.tgt_len,
        num_core_per_host=FLAGS.num_core_per_host,
        num_hosts=FLAGS.num_hosts,
        use_tpu=FLAGS.use_tpu)
    eval_bin_sizes = eval_record_info["bin_sizes"]
    num_eval_batch = eval_record_info["num_batch"]

    if FLAGS.max_eval_batch > 0:
      num_eval_batch = min(FLAGS.max_eval_batch, num_eval_batch)

    # Get eval cache function
    eval_cache_fn = get_cache_fn(FLAGS.mem_len)
    model_fn = get_model_fn(n_token, cutoffs, train_bin_sizes, eval_bin_sizes)
  else:
    eval_cache_fn = None
    model_fn = get_model_fn(n_token, cutoffs, train_bin_sizes, [])

  ##### Create estimator
  # TPU Configuration
  tpu_cluster_resolver = tf.contrib.cluster_resolver.TPUClusterResolver(
      FLAGS.tpu, zone=FLAGS.tpu_zone, project=FLAGS.gcp_project)

  per_host_input = tf.contrib.tpu.InputPipelineConfig.PER_HOST_V2
  run_config = tf.contrib.tpu.RunConfig(
      cluster=tpu_cluster_resolver,
      model_dir=FLAGS.model_dir,
      session_config=tf.ConfigProto(
          allow_soft_placement=True, log_device_placement=True),
      tpu_config=tf.contrib.tpu.TPUConfig(
          iterations_per_loop=FLAGS.iterations,
          num_shards=FLAGS.num_core_per_host * FLAGS.num_hosts,
          per_host_input_for_training=per_host_input),
      keep_checkpoint_max=100000,  # effectively save all checkpoints
      save_checkpoints_secs=None,
      save_checkpoints_steps=FLAGS.save_steps)

  # warm start
  warm_start_from = None
  if FLAGS.warm_start_path is not None:
    warm_start_from = tf.estimator.WarmStartSettings(
        ckpt_to_initialize_from=FLAGS.warm_start_path)

  # TPU Estimator
  estimator = tpu_estimator.TPUEstimator(
      model_fn=model_fn,
      train_cache_fn=train_cache_fn,
      eval_cache_fn=eval_cache_fn,
      use_tpu=FLAGS.use_tpu,
      config=run_config,
      params={"data_dir": FLAGS.data_dir, "track_mean": FLAGS.track_mean},
      train_batch_size=FLAGS.train_batch_size,
      eval_batch_size=FLAGS.eval_batch_size,
      warm_start_from=warm_start_from)

  if FLAGS.do_eval_only:
    if FLAGS.eval_ckpt_path is not None:
      ret = estimator.evaluate(input_fn=eval_input_fn, steps=num_eval_batch,
                               checkpoint_path=FLAGS.eval_ckpt_path)
      tf.logging.info("=" * 200)
      log_str = "Eval results | "
      for key, val in ret.items():
        log_str += "{} {} | ".format(key, val)
      tf.logging.info(log_str)
      tf.logging.info("=" * 200)
    else:
      ckpt_state = tf.train.get_checkpoint_state(FLAGS.model_dir)
      eval_results = []
      for eval_checkpoint in ckpt_state.all_model_checkpoint_paths:
        if not exists(eval_checkpoint + ".index"):
          continue
        global_step = int(eval_checkpoint.split("-")[-1])
        if (global_step < FLAGS.start_eval_steps or
            global_step > FLAGS.train_steps):
          continue
        ret = estimator.evaluate(input_fn=eval_input_fn, steps=num_eval_batch,
                                 checkpoint_path=eval_checkpoint)
        eval_results.append(ret)

      eval_results.sort(key=lambda x: x["perplexity"])

      tf.logging.info("=" * 200)
      log_str = "Best results | "
      for key, val in eval_results[0].items():
        log_str += "{} {} | ".format(key, val)
      tf.logging.info(log_str)
      tf.logging.info("=" * 200)
  else:
    if not FLAGS.do_eval:
      estimator.train(input_fn=train_input_fn, steps=FLAGS.train_steps)
    else:
      for step in range(0, FLAGS.train_steps, num_train_batch):
        train_steps = min(FLAGS.train_steps - step, num_train_batch)
        estimator.train(input_fn=train_input_fn, steps=train_steps)
        estimator.evaluate(input_fn=eval_input_fn, steps=num_eval_batch)


if __name__ == "__main__":
  tf.app.run()
TensorFlow/NLP/transformer-xl-master/train_fp16.py 0 → 100644
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os
import math
import time

from absl import flags
import absl.logging as _logging  # pylint: disable=unused-import

import tensorflow as tf
import model
import data_utils

from gpu_utils import assign_to_gpu, average_grads_and_vars

import numpy as np

# GPU config
flags.DEFINE_integer("num_hosts", default=1, help="Number of TPU hosts")
flags.DEFINE_integer("num_core_per_host", default=8, help="Number of cores per host")

# Experiment (data/checkpoint/directory) config
flags.DEFINE_string("data_dir", default="", help="Path to tf-records directory.")
flags.DEFINE_string("record_info_dir", default="",
      help="Path to local directory containing filenames.txt.")
flags.DEFINE_string("corpus_info_path", default="", help="Path to corpus-info.json file.")
flags.DEFINE_string("model_dir", default=None, help="Estimator model_dir.")
flags.DEFINE_bool("do_train", default=True, help="Whether to run training.")
flags.DEFINE_bool("do_eval", default=False, help="Whether to run eval on the dev set.")
flags.DEFINE_string("eval_ckpt_path", None,
      help="Checkpoint path for do_test evaluation. "
      "If set, model_dir will be ignored. "
      "If unset, will use the latest ckpt in model_dir.")
flags.DEFINE_string("warm_start_path", None,
      help="Checkpoint path for warm start. "
      "If set, will clear Adam states. "
      "Note that the new model_dir should be different from warm_start_path.")

# Optimization config
flags.DEFINE_float("learning_rate", default=2.5e-4, help="Maximum learning rate.")
flags.DEFINE_float("clip", default=0.25, help="Gradient clipping value.")
# for cosine decay
flags.DEFINE_float("min_lr_ratio", default=0.004, help="Minimum ratio learning rate.")
flags.DEFINE_integer("warmup_steps", default=0, help="Number of steps for linear lr warmup.")

# Training config
flags.DEFINE_integer("train_batch_size", default=60, help="Size of train batch.")
flags.DEFINE_integer("eval_batch_size", default=60, help="Size of valid batch.")
flags.DEFINE_integer("train_steps", default=100000, help="Total number of training steps.")
flags.DEFINE_integer("iterations", default=500, help="Number of iterations per repeat loop.")
flags.DEFINE_integer("save_steps", default=10000, help="Number of steps for model checkpointing.")

# Evaluation config
flags.DEFINE_bool("do_test", default=False, help="Run on the test set.")
flags.DEFINE_integer("max_eval_batch", default=-1,
      help="Set -1 to turn off. Only used in test mode.")
flags.DEFINE_bool("do_eval_only", default=False, help="Run evaluation only.")
flags.DEFINE_integer("start_eval_steps", default=10000,
      help="Which checkpoint to start with in `do_eval_only` mode.")
flags.DEFINE_string("eval_split", "valid", help="Which data split to evaluate.")

# Model config
flags.DEFINE_integer("tgt_len", default=70, help="Number of steps to predict")
flags.DEFINE_integer("mem_len", default=70, help="Number of steps to cache")
flags.DEFINE_bool("same_length", default=False, help="Same length attention")
flags.DEFINE_integer("clamp_len", default=-1, help="Clamp length")

flags.DEFINE_integer("n_layer", default=6, help="Number of layers.")
flags.DEFINE_integer("d_model", default=500, help="Dimension of the model.")
flags.DEFINE_integer("d_embed", default=500, help="Dimension of the embeddings.")
flags.DEFINE_integer("n_head", default=10, help="Number of attention heads.")
flags.DEFINE_integer("d_head", default=50, help="Dimension of each attention head.")
flags.DEFINE_integer("d_inner", default=1000,
      help="Dimension of inner hidden size in positionwise feed-forward.")
flags.DEFINE_float("dropout", default=0.1, help="Dropout rate.")
flags.DEFINE_float("dropatt", default=0.1, help="Attention dropout rate.")
flags.DEFINE_bool("untie_r", default=False, help="untie r_w_bias and r_r_bias")

# Adaptive Softmax / Embedding
flags.DEFINE_bool("tie_weight", default=True, help="Tie embedding and softmax weight.")
flags.DEFINE_integer("div_val", default=1,
      help="Divide the embedding size by this val for each bin")
flags.DEFINE_bool("proj_share_all_but_first", default=False,
      help="True to share all but first projs, False not to share.")
flags.DEFINE_bool("proj_same_dim", default=True,
      help="Project the bin with the same dimension.")

# Parameter initialization
flags.DEFINE_enum("init", default="normal", enum_values=["normal", "uniform"],
      help="Initialization method.")
flags.DEFINE_float("init_std", default=0.02, help="Initialization std when init is normal.")
flags.DEFINE_float("proj_init_std", default=0.01, help="Initialization std for embedding projection.")
flags.DEFINE_float("init_range", default=0.1, help="Initialization std when init is uniform.")

FLAGS = flags.FLAGS
def get_model_fn(n_token, cutoffs):
  def model_fn(inp, tgt, mems, is_training):
    inp = tf.transpose(inp, [1, 0])
    tgt = tf.transpose(tgt, [1, 0])

    if FLAGS.init == "uniform":
      initializer = tf.initializers.random_uniform(
          minval=-FLAGS.init_range,
          maxval=FLAGS.init_range,
          seed=None)
    elif FLAGS.init == "normal":
      initializer = tf.initializers.random_normal(
          stddev=FLAGS.init_std,
          seed=None)
      proj_initializer = tf.initializers.random_normal(
          stddev=FLAGS.proj_init_std,
          seed=None)

    tie_projs = [False for _ in range(len(cutoffs) + 1)]
    if FLAGS.proj_share_all_but_first:
      for i in range(1, len(tie_projs)):
        tie_projs[i] = True

    loss, new_mems = model.transformer(
        dec_inp=inp,
        target=tgt,
        mems=mems,
        n_token=n_token,
        n_layer=FLAGS.n_layer,
        d_model=FLAGS.d_model,
        d_embed=FLAGS.d_embed,
        n_head=FLAGS.n_head,
        d_head=FLAGS.d_head,
        d_inner=FLAGS.d_inner,
        dropout=FLAGS.dropout,
        dropatt=FLAGS.dropatt,
        initializer=initializer,
        proj_initializer=proj_initializer,
        is_training=is_training,
        mem_len=FLAGS.mem_len,
        cutoffs=cutoffs,
        div_val=FLAGS.div_val,
        tie_projs=tie_projs,
        input_perms=None,
        target_perms=None,
        head_target=None,
        same_length=FLAGS.same_length,
        clamp_len=FLAGS.clamp_len,
        use_tpu=False,
        untie_r=FLAGS.untie_r,
        proj_same_dim=FLAGS.proj_same_dim)

    # number of parameters
    num_params = sum([np.prod(v.shape) for v in tf.trainable_variables()])
    tf.logging.info('#params: {}'.format(num_params))

    # format_str = '{{:<{0}s}}\t{{}}'.format(
    #     max([len(v.name) for v in tf.trainable_variables()]))
    # for v in tf.trainable_variables():
    #   tf.logging.info(format_str.format(v.name, v.get_shape()))

    if is_training:
      all_vars = tf.trainable_variables()
      grads = tf.gradients(loss, all_vars)
      grads_and_vars = list(zip(grads, all_vars))

      return loss, new_mems, grads_and_vars
    else:
      return loss, new_mems

  return model_fn


def single_core_graph(n_token, cutoffs, is_training, inp, tgt, mems):
  model_fn = get_model_fn(
      n_token=n_token,
      cutoffs=cutoffs)

  model_ret = model_fn(
      inp=inp,
      tgt=tgt,
      mems=mems,
      is_training=is_training)

  return model_ret
def train(n_token, cutoffs, ps_device):
  ##### Get input function and model function
  train_input_fn, train_record_info = data_utils.get_input_fn(
      record_info_dir=FLAGS.record_info_dir,
      split="train",
      per_host_bsz=FLAGS.train_batch_size,
      tgt_len=FLAGS.tgt_len,
      num_core_per_host=FLAGS.num_core_per_host,
      num_hosts=1,
      use_tpu=False)

  tf.logging.info("num of batches {}".format(train_record_info["num_batch"]))

  ##### Create computational graph
  train_set = train_input_fn({
      "batch_size": FLAGS.train_batch_size,
      "data_dir": FLAGS.data_dir})

  input_feed, label_feed = train_set.make_one_shot_iterator().get_next()

  inputs = tf.split(input_feed, FLAGS.num_core_per_host, 0)
  labels = tf.split(label_feed, FLAGS.num_core_per_host, 0)

  per_core_bsz = FLAGS.train_batch_size // FLAGS.num_core_per_host

  tower_mems, tower_losses, tower_new_mems, tower_grads_and_vars = [], [], [], []

  for i in range(FLAGS.num_core_per_host):
    reuse = True if i > 0 else None
    with tf.device(assign_to_gpu(i, ps_device)), \
        tf.variable_scope(tf.get_variable_scope(), reuse=reuse):

      mems_i = [tf.placeholder(tf.float32,
                               [FLAGS.mem_len, per_core_bsz, FLAGS.d_model])
                for _ in range(FLAGS.n_layer)]

      loss_i, new_mems_i, grads_and_vars_i = single_core_graph(
          n_token=n_token,
          cutoffs=cutoffs,
          is_training=True,
          inp=inputs[i],
          tgt=labels[i],
          mems=mems_i)

      tower_mems.append(mems_i)
      tower_losses.append(loss_i)
      tower_new_mems.append(new_mems_i)
      tower_grads_and_vars.append(grads_and_vars_i)

  ## average losses and gradients across towers
  if len(tower_losses) > 1:
    loss = tf.add_n(tower_losses) / len(tower_losses)
    grads_and_vars = average_grads_and_vars(tower_grads_and_vars)
  else:
    loss = tower_losses[0]
    grads_and_vars = tower_grads_and_vars[0]
  grads, all_vars = zip(*grads_and_vars)

  ## clip gradient
  clipped, gnorm = tf.clip_by_global_norm(grads, FLAGS.clip)
  grads_and_vars = list(zip(clipped, all_vars))

  ## configure the optimizer
  global_step = tf.train.get_or_create_global_step()

  # warmup stage: increase the learning rate linearly
  if FLAGS.warmup_steps > 0:
    warmup_lr = tf.to_float(global_step) / tf.to_float(FLAGS.warmup_steps) \
        * FLAGS.learning_rate
  else:
    warmup_lr = 0.0

  # decay stage: decay the learning rate using the cosine schedule
  decay_lr = tf.train.cosine_decay(
      FLAGS.learning_rate,
      global_step=global_step - FLAGS.warmup_steps,
      decay_steps=FLAGS.train_steps - FLAGS.warmup_steps,
      alpha=FLAGS.min_lr_ratio)

  # choose warmup or decay
  learning_rate = tf.where(global_step < FLAGS.warmup_steps,
                           warmup_lr, decay_lr)

  # get the train op
  optimizer = tf.train.AdamOptimizer(learning_rate=learning_rate)
  train_op = optimizer.apply_gradients(grads_and_vars, global_step)

  ##### Training loop
  tower_mems_np = [
      [np.zeros([FLAGS.mem_len, per_core_bsz, FLAGS.d_model], dtype=np.float32)
       for layer in range(FLAGS.n_layer)]
      for core in range(FLAGS.num_core_per_host)
  ]

  saver = tf.train.Saver()

  with tf.Session(config=tf.ConfigProto(allow_soft_placement=True)) as sess:
    sess.run(tf.global_variables_initializer())

    if FLAGS.warm_start_path is not None:
      tf.logging.info("warm start from {}".format(FLAGS.warm_start_path))
      saver.restore(sess, FLAGS.warm_start_path)

    fetches = [loss, tower_new_mems, global_step, gnorm, learning_rate,
               train_op]

    total_loss, prev_step = 0., -1
    while True:
      feed_dict = {}
      for i in range(FLAGS.num_core_per_host):
        for m, m_np in zip(tower_mems[i], tower_mems_np[i]):
          feed_dict[m] = m_np

      # modified: time each session step and log the resulting steps/sec
      s_time = time.time()
      fetched = sess.run(fetches, feed_dict=feed_dict)
      e_time = time.time()
      global_step_s = 1 / (e_time - s_time)
      tf.logging.info("global_step/sec : {}".format(global_step_s))

      loss_np, tower_mems_np, curr_step = fetched[:3]
      total_loss += loss_np

      if curr_step > 0 and curr_step % FLAGS.iterations == 0:
        curr_loss = total_loss / (curr_step - prev_step)
        tf.logging.info("[{}] | gnorm {:.2f} lr {:8.6f} "
                        "| loss {:.2f} | pplx {:>7.2f}, bpc {:>7.4f}".format(
                            curr_step, fetched[-3], fetched[-2], curr_loss,
                            math.exp(curr_loss), curr_loss / math.log(2)))
        total_loss, prev_step = 0., curr_step

      if curr_step > 0 and curr_step % FLAGS.save_steps == 0:
        save_path = os.path.join(FLAGS.model_dir, "model.ckpt")
        saver.save(sess, save_path)
        tf.logging.info("Model saved in path: {}".format(save_path))

      if curr_step == FLAGS.train_steps:
        break
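
# A plain-Python sketch (not part of the original file) of the learning-rate
# schedule built above: linear warmup for `warmup_steps`, then cosine decay
# from `learning_rate` down to `learning_rate * min_lr_ratio`, mirroring
# tf.train.cosine_decay with alpha=min_lr_ratio. Defaults follow the flags.
def _example_lr_at_step(step, learning_rate=2.5e-4, warmup_steps=0,
                        train_steps=100000, min_lr_ratio=0.004):
  if warmup_steps > 0 and step < warmup_steps:
    return step / float(warmup_steps) * learning_rate
  progress = (step - warmup_steps) / float(train_steps - warmup_steps)
  cosine = 0.5 * (1 + math.cos(math.pi * min(max(progress, 0.0), 1.0)))
  return learning_rate * ((1 - min_lr_ratio) * cosine + min_lr_ratio)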
def evaluate(n_token, cutoffs, ps_device):
  ##### Get input function and model function
  eval_input_fn, eval_record_info = data_utils.get_input_fn(
      record_info_dir=FLAGS.record_info_dir,
      split=FLAGS.eval_split,
      per_host_bsz=FLAGS.eval_batch_size,
      tgt_len=FLAGS.tgt_len,
      num_core_per_host=FLAGS.num_core_per_host,
      num_hosts=1,
      use_tpu=False)

  num_batch = eval_record_info["num_batch"]
  if FLAGS.max_eval_batch > 0:
    num_batch = FLAGS.max_eval_batch
  tf.logging.info("num of batches {}".format(num_batch))

  ##### Create computational graph
  eval_set = eval_input_fn({
      "batch_size": FLAGS.eval_batch_size,
      "data_dir": FLAGS.data_dir})

  input_feed, label_feed = eval_set.make_one_shot_iterator().get_next()

  inputs = tf.split(input_feed, FLAGS.num_core_per_host, 0)
  labels = tf.split(label_feed, FLAGS.num_core_per_host, 0)

  per_core_bsz = FLAGS.eval_batch_size // FLAGS.num_core_per_host
  tower_mems, tower_losses, tower_new_mems = [], [], []

  for i in range(FLAGS.num_core_per_host):
    with tf.device(assign_to_gpu(i, ps_device)), \
        tf.variable_scope(tf.get_variable_scope(), reuse=tf.AUTO_REUSE):

      mems_i = [tf.placeholder(tf.float32,
                               [FLAGS.mem_len, per_core_bsz, FLAGS.d_model])
                for _ in range(FLAGS.n_layer)]

      loss_i, new_mems_i = single_core_graph(
          n_token=n_token,
          cutoffs=cutoffs,
          is_training=False,
          inp=inputs[i],
          tgt=labels[i],
          mems=mems_i)

      tower_mems.append(mems_i)
      tower_losses.append(loss_i)
      tower_new_mems.append(new_mems_i)

  ## sum losses across towers
  if len(tower_losses) > 1:
    loss = tf.add_n(tower_losses) / len(tower_losses)
  else:
    loss = tower_losses[0]

  ##### Evaluation loop
  tower_mems_np = [
      [np.zeros([FLAGS.mem_len, per_core_bsz, FLAGS.d_model], dtype=np.float32)
       for layer in range(FLAGS.n_layer)]
      for core in range(FLAGS.num_core_per_host)
  ]

  saver = tf.train.Saver()

  with tf.Session(config=tf.ConfigProto(allow_soft_placement=True)) as sess:
    sess.run(tf.global_variables_initializer())

    if FLAGS.eval_ckpt_path is None:
      eval_ckpt_path = tf.train.latest_checkpoint(FLAGS.model_dir)
    else:
      eval_ckpt_path = FLAGS.eval_ckpt_path
    tf.logging.info("Evaluate {}".format(eval_ckpt_path))
    saver.restore(sess, eval_ckpt_path)

    fetches = [loss, tower_new_mems, tf.size(label_feed)]

    format_str = "  >> processing batch {{:{0}d}}/{{:{0}d}} ..".format(
        len(str(num_batch)))

    total_loss, total_cnt = 0, 0
    for step in range(num_batch):
      if step % (num_batch // 10) == 0:
        tf.logging.info(format_str.format(step, num_batch))

      feed_dict = {}
      for i in range(FLAGS.num_core_per_host):
        for m, m_np in zip(tower_mems[i], tower_mems_np[i]):
          feed_dict[m] = m_np

      fetched = sess.run(fetches, feed_dict=feed_dict)

      loss_np, tower_mems_np, cnt_np = fetched[:3]
      total_loss += loss_np * cnt_np
      total_cnt += cnt_np

    avg_loss = total_loss / total_cnt
    tf.logging.info("| loss {:.2f} | pplx {:>7.2f}, bpc {:>7.4f}".format(
        avg_loss, math.exp(avg_loss), avg_loss / math.log(2)))
def main(unused_argv):
  del unused_argv  # Unused

  tf.logging.set_verbosity(tf.logging.INFO)

  # Get corpus info
  corpus_info = data_utils.get_corpus_info(FLAGS.corpus_info_path)
  n_token = corpus_info["vocab_size"]
  cutoffs = corpus_info["cutoffs"][1:-1]
  tf.logging.info("n_token {}".format(n_token))

  if FLAGS.do_train:
    train(n_token, cutoffs, "/gpu:0")
  if FLAGS.do_eval:
    evaluate(n_token, cutoffs, "/gpu:0")


if __name__ == "__main__":
  tf.app.run()
TensorFlow/NLP/transformer-xl-master/train_gpu.py 0 → 100644
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os
import math
import time

from absl import flags
import absl.logging as _logging  # pylint: disable=unused-import

import tensorflow as tf
import model
import data_utils

from gpu_utils import assign_to_gpu, average_grads_and_vars

import numpy as np

# GPU config
flags.DEFINE_integer("num_hosts", default=1, help="Number of TPU hosts")
flags.DEFINE_integer("num_core_per_host", default=8, help="Number of cores per host")

# Experiment (data/checkpoint/directory) config
flags.DEFINE_string("data_dir", default="", help="Path to tf-records directory.")
flags.DEFINE_string("record_info_dir", default="",
      help="Path to local directory containing filenames.txt.")
flags.DEFINE_string("corpus_info_path", default="", help="Path to corpus-info.json file.")
flags.DEFINE_string("model_dir", default=None, help="Estimator model_dir.")
flags.DEFINE_bool("do_train", default=True, help="Whether to run training.")
flags.DEFINE_bool("do_eval", default=False, help="Whether to run eval on the dev set.")
flags.DEFINE_string("eval_ckpt_path", None,
      help="Checkpoint path for do_test evaluation. "
      "If set, model_dir will be ignored. "
      "If unset, will use the latest ckpt in model_dir.")
flags.DEFINE_string("warm_start_path", None,
      help="Checkpoint path for warm start. "
      "If set, will clear Adam states. "
      "Note that the new model_dir should be different from warm_start_path.")

# Optimization config
flags.DEFINE_float("learning_rate", default=2.5e-4, help="Maximum learning rate.")
flags.DEFINE_float("clip", default=0.25, help="Gradient clipping value.")
# for cosine decay
flags.DEFINE_float("min_lr_ratio", default=0.004, help="Minimum ratio learning rate.")
flags.DEFINE_integer("warmup_steps", default=0, help="Number of steps for linear lr warmup.")

# Training config
flags.DEFINE_integer("train_batch_size", default=60, help="Size of train batch.")
flags.DEFINE_integer("eval_batch_size", default=60, help="Size of valid batch.")
flags.DEFINE_integer("train_steps", default=100000, help="Total number of training steps.")
flags.DEFINE_integer("iterations", default=500, help="Number of iterations per repeat loop.")
flags.DEFINE_integer("save_steps", default=10000, help="Number of steps for model checkpointing.")

# Evaluation config
flags.DEFINE_bool("do_test", default=False, help="Run on the test set.")
flags.DEFINE_integer("max_eval_batch", default=-1,
      help="Set -1 to turn off. Only used in test mode.")
flags.DEFINE_bool("do_eval_only", default=False, help="Run evaluation only.")
flags.DEFINE_integer("start_eval_steps", default=10000,
      help="Which checkpoint to start with in `do_eval_only` mode.")
flags.DEFINE_string("eval_split", "valid", help="Which data split to evaluate.")

# Model config
flags.DEFINE_integer("tgt_len", default=70, help="Number of steps to predict")
flags.DEFINE_integer("mem_len", default=70, help="Number of steps to cache")
flags.DEFINE_bool("same_length", default=False, help="Same length attention")
flags.DEFINE_integer("clamp_len", default=-1, help="Clamp length")

flags.DEFINE_integer("n_layer", default=6, help="Number of layers.")
flags.DEFINE_integer("d_model", default=500, help="Dimension of the model.")
flags.DEFINE_integer("d_embed", default=500, help="Dimension of the embeddings.")
flags.DEFINE_integer("n_head", default=10, help="Number of attention heads.")
flags.DEFINE_integer("d_head", default=50, help="Dimension of each attention head.")
flags.DEFINE_integer("d_inner", default=1000,
      help="Dimension of inner hidden size in positionwise feed-forward.")
flags.DEFINE_float("dropout", default=0.1, help="Dropout rate.")
flags.DEFINE_float("dropatt", default=0.1, help="Attention dropout rate.")
flags.DEFINE_bool("untie_r", default=False, help="untie r_w_bias and r_r_bias")

# Adaptive Softmax / Embedding
flags.DEFINE_bool("tie_weight", default=True, help="Tie embedding and softmax weight.")
flags.DEFINE_integer("div_val", default=1,
      help="Divide the embedding size by this val for each bin")
flags.DEFINE_bool("proj_share_all_but_first", default=False,
      help="True to share all but first projs, False not to share.")
flags.DEFINE_bool("proj_same_dim", default=True,
      help="Project the bin with the same dimension.")

# Parameter initialization
flags.DEFINE_enum("init", default="normal", enum_values=["normal", "uniform"],
      help="Initialization method.")
flags.DEFINE_float("init_std", default=0.02, help="Initialization std when init is normal.")
flags.DEFINE_float("proj_init_std", default=0.01, help="Initialization std for embedding projection.")
flags.DEFINE_float("init_range", default=0.1, help="Initialization std when init is uniform.")

FLAGS = flags.FLAGS
def
get_model_fn
(
n_token
,
cutoffs
):
def
model_fn
(
inp
,
tgt
,
mems
,
is_training
):
inp
=
tf
.
transpose
(
inp
,
[
1
,
0
])
tgt
=
tf
.
transpose
(
tgt
,
[
1
,
0
])
if
FLAGS
.
init
==
"uniform"
:
initializer
=
tf
.
initializers
.
random_uniform
(
minval
=-
FLAGS
.
init_range
,
maxval
=
FLAGS
.
init_range
,
seed
=
None
)
elif
FLAGS
.
init
==
"normal"
:
initializer
=
tf
.
initializers
.
random_normal
(
stddev
=
FLAGS
.
init_std
,
seed
=
None
)
proj_initializer
=
tf
.
initializers
.
random_normal
(
stddev
=
FLAGS
.
proj_init_std
,
seed
=
None
)
tie_projs
=
[
False
for
_
in
range
(
len
(
cutoffs
)
+
1
)]
if
FLAGS
.
proj_share_all_but_first
:
for
i
in
range
(
1
,
len
(
tie_projs
)):
tie_projs
[
i
]
=
True
loss
,
new_mems
=
model
.
transformer
(
dec_inp
=
inp
,
target
=
tgt
,
mems
=
mems
,
n_token
=
n_token
,
n_layer
=
FLAGS
.
n_layer
,
d_model
=
FLAGS
.
d_model
,
d_embed
=
FLAGS
.
d_embed
,
n_head
=
FLAGS
.
n_head
,
d_head
=
FLAGS
.
d_head
,
d_inner
=
FLAGS
.
d_inner
,
dropout
=
FLAGS
.
dropout
,
dropatt
=
FLAGS
.
dropatt
,
initializer
=
initializer
,
proj_initializer
=
proj_initializer
,
is_training
=
is_training
,
mem_len
=
FLAGS
.
mem_len
,
cutoffs
=
cutoffs
,
div_val
=
FLAGS
.
div_val
,
tie_projs
=
tie_projs
,
input_perms
=
None
,
target_perms
=
None
,
head_target
=
None
,
same_length
=
FLAGS
.
same_length
,
clamp_len
=
FLAGS
.
clamp_len
,
use_tpu
=
False
,
untie_r
=
FLAGS
.
untie_r
,
proj_same_dim
=
FLAGS
.
proj_same_dim
)
# number of parameters
num_params
=
sum
([
np
.
prod
(
v
.
shape
)
for
v
in
tf
.
trainable_variables
()])
tf
.
logging
.
info
(
'#params: {}'
.
format
(
num_params
))
# format_str = '{{:<{0}s}}\t{{}}'.format(
# max([len(v.name) for v in tf.trainable_variables()]))
# for v in tf.trainable_variables():
# tf.logging.info(format_str.format(v.name, v.get_shape()))
if
is_training
:
all_vars
=
tf
.
trainable_variables
()
grads
=
tf
.
gradients
(
loss
,
all_vars
)
grads_and_vars
=
list
(
zip
(
grads
,
all_vars
))
return
loss
,
new_mems
,
grads_and_vars
else
:
return
loss
,
new_mems
return
model_fn
def
single_core_graph
(
n_token
,
cutoffs
,
is_training
,
inp
,
tgt
,
mems
):
model_fn
=
get_model_fn
(
n_token
=
n_token
,
cutoffs
=
cutoffs
)
model_ret
=
model_fn
(
inp
=
inp
,
tgt
=
tgt
,
mems
=
mems
,
is_training
=
is_training
)
return
model_ret
def
train
(
n_token
,
cutoffs
,
ps_device
):
##### Get input function and model function
train_input_fn
,
train_record_info
=
data_utils
.
get_input_fn
(
record_info_dir
=
FLAGS
.
record_info_dir
,
split
=
"train"
,
per_host_bsz
=
FLAGS
.
train_batch_size
,
tgt_len
=
FLAGS
.
tgt_len
,
num_core_per_host
=
FLAGS
.
num_core_per_host
,
num_hosts
=
1
,
use_tpu
=
False
)
tf
.
logging
.
info
(
"num of batches {}"
.
format
(
train_record_info
[
"num_batch"
]))
##### Create computational graph
train_set
=
train_input_fn
({
"batch_size"
:
FLAGS
.
train_batch_size
,
"data_dir"
:
FLAGS
.
data_dir
})
input_feed
,
label_feed
=
train_set
.
make_one_shot_iterator
().
get_next
()
inputs
=
tf
.
split
(
input_feed
,
FLAGS
.
num_core_per_host
,
0
)
labels
=
tf
.
split
(
label_feed
,
FLAGS
.
num_core_per_host
,
0
)
per_core_bsz
=
FLAGS
.
train_batch_size
//
FLAGS
.
num_core_per_host
tower_mems
,
tower_losses
,
tower_new_mems
,
tower_grads_and_vars
=
[],
[],
[],
[]
for
i
in
range
(
FLAGS
.
num_core_per_host
):
reuse
=
True
if
i
>
0
else
None
with
tf
.
device
(
assign_to_gpu
(
i
,
ps_device
)),
\
tf
.
variable_scope
(
tf
.
get_variable_scope
(),
reuse
=
reuse
):
mems_i
=
[
tf
.
placeholder
(
tf
.
float32
,
[
FLAGS
.
mem_len
,
per_core_bsz
,
FLAGS
.
d_model
])
for
_
in
range
(
FLAGS
.
n_layer
)]
loss_i
,
new_mems_i
,
grads_and_vars_i
=
single_core_graph
(
n_token
=
n_token
,
cutoffs
=
cutoffs
,
is_training
=
True
,
inp
=
inputs
[
i
],
tgt
=
labels
[
i
],
mems
=
mems_i
)
tower_mems
.
append
(
mems_i
)
tower_losses
.
append
(
loss_i
)
tower_new_mems
.
append
(
new_mems_i
)
tower_grads_and_vars
.
append
(
grads_and_vars_i
)
## average losses and gradients across towers
if
len
(
tower_losses
)
>
1
:
loss
=
tf
.
add_n
(
tower_losses
)
/
len
(
tower_losses
)
grads_and_vars
=
average_grads_and_vars
(
tower_grads_and_vars
)
else
:
loss
=
tower_losses
[
0
]
grads_and_vars
=
tower_grads_and_vars
[
0
]
grads
,
all_vars
=
zip
(
*
grads_and_vars
)
## clip gradient
clipped
,
gnorm
=
tf
.
clip_by_global_norm
(
grads
,
FLAGS
.
clip
)
grads_and_vars
=
list
(
zip
(
clipped
,
all_vars
))
## configure the optimizer
global_step
=
tf
.
train
.
get_or_create_global_step
()
# warmup stage: increase the learning rate linearly
if
FLAGS
.
warmup_steps
>
0
:
warmup_lr
=
tf
.
to_float
(
global_step
)
/
tf
.
to_float
(
FLAGS
.
warmup_steps
)
\
*
FLAGS
.
learning_rate
else
:
warmup_lr
=
0.0
# decay stage: decay the learning rate using the cosine schedule
decay_lr
=
tf
.
train
.
cosine_decay
(
FLAGS
.
learning_rate
,
global_step
=
global_step
-
FLAGS
.
warmup_steps
,
decay_steps
=
FLAGS
.
train_steps
-
FLAGS
.
warmup_steps
,
alpha
=
FLAGS
.
min_lr_ratio
)
# choose warmup or decay
learning_rate
=
tf
.
where
(
global_step
<
FLAGS
.
warmup_steps
,
warmup_lr
,
decay_lr
)
# get the train op
optimizer
=
tf
.
train
.
AdamOptimizer
(
learning_rate
=
learning_rate
)
train_op
=
optimizer
.
apply_gradients
(
grads_and_vars
,
global_step
)
##### Training loop
tower_mems_np
=
[
[
np
.
zeros
([
FLAGS
.
mem_len
,
per_core_bsz
,
FLAGS
.
d_model
],
dtype
=
np
.
float32
)
for
layer
in
range
(
FLAGS
.
n_layer
)]
for
core
in
range
(
FLAGS
.
num_core_per_host
)
]
saver
=
tf
.
train
.
Saver
()
with
tf
.
Session
(
config
=
tf
.
ConfigProto
(
allow_soft_placement
=
True
))
as
sess
:
sess
.
run
(
tf
.
global_variables_initializer
())
if
FLAGS
.
warm_start_path
is
not
None
:
tf
.
logging
.
info
(
"warm start from {}"
.
format
(
FLAGS
.
warm_start_path
))
saver
.
restore
(
sess
,
FLAGS
.
warm_start_path
)
fetches
=
[
loss
,
tower_new_mems
,
global_step
,
gnorm
,
learning_rate
,
train_op
]
total_loss
,
prev_step
=
0.
,
-
1
while
True
:
feed_dict
=
{}
for
i
in
range
(
FLAGS
.
num_core_per_host
):
for
m
,
m_np
in
zip
(
tower_mems
[
i
],
tower_mems_np
[
i
]):
feed_dict
[
m
]
=
m_np
#改
s_time
=
time
.
time
()
fetched
=
sess
.
run
(
fetches
,
feed_dict
=
feed_dict
)
e_time
=
time
.
time
()
global_step_s
=
1
/
(
e_time
-
s_time
)
tf
.
logging
.
info
(
"global_step/sec : {}"
.
format
(
global_step_s
))
loss_np
,
tower_mems_np
,
curr_step
=
fetched
[:
3
]
total_loss
+=
loss_np
if
curr_step
>
0
and
curr_step
%
FLAGS
.
iterations
==
0
:
curr_loss
=
total_loss
/
(
curr_step
-
prev_step
)
tf
.
logging
.
info
(
"[{}] | gnorm {:.2f} lr {:8.6f} "
"| loss {:.2f} | pplx {:>7.2f}, bpc {:>7.4f}"
.
format
(
curr_step
,
fetched
[
-
3
],
fetched
[
-
2
],
curr_loss
,
math
.
exp
(
curr_loss
),
curr_loss
/
math
.
log
(
2
)))
total_loss
,
prev_step
=
0.
,
curr_step
if
curr_step
>
0
and
curr_step
%
FLAGS
.
save_steps
==
0
:
save_path
=
os
.
path
.
join
(
FLAGS
.
model_dir
,
"model.ckpt"
)
saver
.
save
(
sess
,
save_path
)
tf
.
logging
.
info
(
"Model saved in path: {}"
.
format
(
save_path
))
if
curr_step
==
FLAGS
.
train_steps
:
break
def
evaluate
(
n_token
,
cutoffs
,
ps_device
):
##### Get input function and model function
eval_input_fn
,
eval_record_info
=
data_utils
.
get_input_fn
(
record_info_dir
=
FLAGS
.
record_info_dir
,
split
=
FLAGS
.
eval_split
,
per_host_bsz
=
FLAGS
.
eval_batch_size
,
tgt_len
=
FLAGS
.
tgt_len
,
num_core_per_host
=
FLAGS
.
num_core_per_host
,
num_hosts
=
1
,
use_tpu
=
False
)
num_batch
=
eval_record_info
[
"num_batch"
]
if
FLAGS
.
max_eval_batch
>
0
:
num_batch
=
FLAGS
.
max_eval_batch
tf
.
logging
.
info
(
"num of batches {}"
.
format
(
num_batch
))
##### Create computational graph
eval_set
=
eval_input_fn
({
"batch_size"
:
FLAGS
.
eval_batch_size
,
"data_dir"
:
FLAGS
.
data_dir
})
input_feed
,
label_feed
=
eval_set
.
make_one_shot_iterator
().
get_next
()
inputs
=
tf
.
split
(
input_feed
,
FLAGS
.
num_core_per_host
,
0
)
labels
=
tf
.
split
(
label_feed
,
FLAGS
.
num_core_per_host
,
0
)
per_core_bsz
=
FLAGS
.
eval_batch_size
//
FLAGS
.
num_core_per_host
tower_mems
,
tower_losses
,
tower_new_mems
=
[],
[],
[]
for
i
in
range
(
FLAGS
.
num_core_per_host
):
with
tf
.
device
(
assign_to_gpu
(
i
,
ps_device
)),
\
tf
.
variable_scope
(
tf
.
get_variable_scope
(),
reuse
=
tf
.
AUTO_REUSE
):
mems_i
=
[
tf
.
placeholder
(
tf
.
float32
,
[
FLAGS
.
mem_len
,
per_core_bsz
,
FLAGS
.
d_model
])
for
_
in
range
(
FLAGS
.
n_layer
)]
loss_i
,
new_mems_i
=
single_core_graph
(
n_token
=
n_token
,
cutoffs
=
cutoffs
,
is_training
=
False
,
inp
=
inputs
[
i
],
tgt
=
labels
[
i
],
mems
=
mems_i
)
tower_mems
.
append
(
mems_i
)
tower_losses
.
append
(
loss_i
)
tower_new_mems
.
append
(
new_mems_i
)
## sum losses across towers
if
len
(
tower_losses
)
>
1
:
loss
=
tf
.
add_n
(
tower_losses
)
/
len
(
tower_losses
)
else
:
loss
=
tower_losses
[
0
]
##### Evaluation loop
tower_mems_np
=
[
[
np
.
zeros
([
FLAGS
.
mem_len
,
per_core_bsz
,
FLAGS
.
d_model
],
dtype
=
np
.
float32
)
for
layer
in
range
(
FLAGS
.
n_layer
)]
for
core
in
range
(
FLAGS
.
num_core_per_host
)
]
saver
=
tf
.
train
.
Saver
()
with
tf
.
Session
(
config
=
tf
.
ConfigProto
(
allow_soft_placement
=
True
))
as
sess
:
sess
.
run
(
tf
.
global_variables_initializer
())
if
FLAGS
.
eval_ckpt_path
is
None
:
eval_ckpt_path
=
tf
.
train
.
latest_checkpoint
(
FLAGS
.
model_dir
)
else
:
eval_ckpt_path
=
FLAGS
.
eval_ckpt_path
tf
.
logging
.
info
(
"Evaluate {}"
.
format
(
eval_ckpt_path
))
saver
.
restore
(
sess
,
eval_ckpt_path
)
fetches
=
[
loss
,
tower_new_mems
,
tf
.
size
(
label_feed
)]
format_str
=
" >> processing batch {{:{0}d}}/{{:{0}d}} .."
.
format
(
len
(
str
(
num_batch
)))
total_loss
,
total_cnt
=
0
,
0
for
step
in
range
(
num_batch
):
if
step
%
(
num_batch
//
10
)
==
0
:
tf
.
logging
.
info
(
format_str
.
format
(
step
,
num_batch
))
feed_dict
=
{}
for
i
in
range
(
FLAGS
.
num_core_per_host
):
for
m
,
m_np
in
zip
(
tower_mems
[
i
],
tower_mems_np
[
i
]):
feed_dict
[
m
]
=
m_np
fetched
=
sess
.
run
(
fetches
,
feed_dict
=
feed_dict
)
loss_np
,
tower_mems_np
,
cnt_np
=
fetched
[:
3
]
total_loss
+=
loss_np
*
cnt_np
total_cnt
+=
cnt_np
avg_loss
=
total_loss
/
total_cnt
tf
.
logging
.
info
(
"| loss {:.2f} | pplx {:>7.2f}, bpc {:>7.4f}"
.
format
(
avg_loss
,
math
.
exp
(
avg_loss
),
avg_loss
/
math
.
log
(
2
)))
def
main
(
unused_argv
):
del
unused_argv
# Unused
tf
.
logging
.
set_verbosity
(
tf
.
logging
.
INFO
)
# Get corpus info
corpus_info
=
data_utils
.
get_corpus_info
(
FLAGS
.
corpus_info_path
)
n_token
=
corpus_info
[
"vocab_size"
]
cutoffs
=
corpus_info
[
"cutoffs"
][
1
:
-
1
]
tf
.
logging
.
info
(
"n_token {}"
.
format
(
n_token
))
if
FLAGS
.
do_train
:
train
(
n_token
,
cutoffs
,
"/gpu:0"
)
if
FLAGS
.
do_eval
:
evaluate
(
n_token
,
cutoffs
,
"/gpu:0"
)
if
__name__
==
"__main__"
:
tf
.
app
.
run
()
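The train op above chains a linear warmup into tf.train.cosine_decay and then selects between the two with tf.where. A minimal NumPy sketch of the resulting schedule, useful for sanity-checking flag choices offline (warmup_steps=4000 is an illustrative value; the script's default is 0, the other numbers are the flag defaults):

import numpy as np

def lr_at(step, max_lr=2.5e-4, warmup_steps=4000, train_steps=100000, min_lr_ratio=0.004):
    """Mirror of the warmup + cosine-decay logic in train(); illustrative only."""
    if warmup_steps > 0 and step < warmup_steps:
        return step / warmup_steps * max_lr          # linear warmup
    # cosine decay from max_lr down to max_lr * min_lr_ratio
    progress = (step - warmup_steps) / (train_steps - warmup_steps)
    cosine = 0.5 * (1 + np.cos(np.pi * progress))
    return max_lr * ((1 - min_lr_ratio) * cosine + min_lr_ratio)

print([round(lr_at(s), 8) for s in (0, 2000, 4000, 50000, 100000)])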
TensorFlow/NLP/transformer-xl-master/train_gpu.py_old
0 → 100644
View file @
cb8dde1c
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
import math
import time
from absl import flags
import absl.logging as _logging # pylint: disable=unused-import
import tensorflow as tf
import model
import data_utils
from gpu_utils import assign_to_gpu, average_grads_and_vars
import numpy as np
# GPU config
flags.DEFINE_integer("num_hosts", default=1,
help="Number of TPU hosts")
flags.DEFINE_integer("num_core_per_host", default=8,
help="Number of cores per host")
# Experiment (data/checkpoint/directory) config
flags.DEFINE_string("data_dir", default="",
help="Path to tf-records directory.")
flags.DEFINE_string("record_info_dir", default="",
help="Path to local directory containing filenames.txt.")
flags.DEFINE_string("corpus_info_path", default="",
help="Path to corpus-info.json file.")
flags.DEFINE_string("model_dir", default=None,
help="Estimator model_dir.")
flags.DEFINE_bool("do_train", default=True,
help="Whether to run training.")
flags.DEFINE_bool("do_eval", default=False,
help="Whether to run eval on the dev set.")
flags.DEFINE_string("eval_ckpt_path", None,
help="Checkpoint path for do_test evaluation."
"If set, model_dir will be ignored."
"If unset, will use the latest ckpt in model_dir.")
flags.DEFINE_string("warm_start_path", None,
help="Checkpoint path for warm start."
"If set, will clear Adam states."
"Note that the new model_dir should be different"
" from warm_start_path.")
# Optimization config
flags.DEFINE_float("learning_rate", default=2.5e-4,
help="Maximum learning rate.")
flags.DEFINE_float("clip", default=0.25,
help="Gradient clipping value.")
# for cosine decay
flags.DEFINE_float("min_lr_ratio", default=0.004,
help="Minimum ratio learning rate.")
flags.DEFINE_integer("warmup_steps", default=0,
help="Number of steps for linear lr warmup.")
# Training config
flags.DEFINE_integer("train_batch_size", default=60,
help="Size of train batch.")
flags.DEFINE_integer("eval_batch_size", default=60,
help="Size of valid batch.")
flags.DEFINE_integer("train_steps", default=100000,
help="Total number of training steps.")
flags.DEFINE_integer("iterations", default=500,
help="Number of iterations per repeat loop.")
flags.DEFINE_integer("save_steps", default=10000,
help="number of steps for model checkpointing.")
# Evaluation config
flags.DEFINE_bool("do_test", default=False,
help="Run on the test set.")
flags.DEFINE_integer("max_eval_batch", default=-1,
help="Set -1 to turn off. Only used in test mode.")
flags.DEFINE_bool("do_eval_only", default=False,
help="Run evaluation only.")
flags.DEFINE_integer("start_eval_steps", default=10000,
help="Which checkpoint to start with in `do_eval_only` mode.")
flags.DEFINE_string("eval_split", "valid",
help="Which data split to evaluate.")
# Model config
flags.DEFINE_integer("tgt_len", default=70,
help="Number of steps to predict")
flags.DEFINE_integer("mem_len", default=70,
help="Number of steps to cache")
flags.DEFINE_bool("same_length", default=False,
help="Same length attention")
flags.DEFINE_integer("clamp_len", default=-1,
help="Clamp length")
flags.DEFINE_integer("n_layer", default=6,
help="Number of layers.")
flags.DEFINE_integer("d_model", default=500,
help="Dimension of the model.")
flags.DEFINE_integer("d_embed", default=500,
help="Dimension of the embeddings.")
flags.DEFINE_integer("n_head", default=10,
help="Number of attention heads.")
flags.DEFINE_integer("d_head", default=50,
help="Dimension of each attention head.")
flags.DEFINE_integer("d_inner", default=1000,
help="Dimension of inner hidden size in positionwise feed-forward.")
flags.DEFINE_float("dropout", default=0.1,
help="Dropout rate.")
flags.DEFINE_float("dropatt", default=0.1,
help="Attention dropout rate.")
flags.DEFINE_bool("untie_r", default=False,
help="untie r_w_bias and r_r_bias")
# Adaptive Softmax / Embedding
flags.DEFINE_bool("tie_weight", default=True,
help="Tie embedding and softmax weight.")
flags.DEFINE_integer("div_val", default=1,
help="Divide the embedding size by this val for each bin")
flags.DEFINE_bool("proj_share_all_but_first", default=False,
help="True to share all but first projs, False not to share.")
flags.DEFINE_bool("proj_same_dim", default=True,
help="Project the bin with the same dimension.")
# Parameter initialization
flags.DEFINE_enum("init", default="normal",
enum_values=["normal", "uniform"],
help="Initialization method.")
flags.DEFINE_float("init_std", default=0.02,
help="Initialization std when init is normal.")
flags.DEFINE_float("proj_init_std", default=0.01,
help="Initialization std for embedding projection.")
flags.DEFINE_float("init_range", default=0.1,
help="Initialization std when init is uniform.")
FLAGS = flags.FLAGS
def get_model_fn(n_token, cutoffs):
def model_fn(inp, tgt, mems, is_training):
inp = tf.transpose(inp, [1, 0])
tgt = tf.transpose(tgt, [1, 0])
if FLAGS.init == "uniform":
initializer = tf.initializers.random_uniform(
minval=-FLAGS.init_range,
maxval=FLAGS.init_range,
seed=None)
elif FLAGS.init == "normal":
initializer = tf.initializers.random_normal(
stddev=FLAGS.init_std,
seed=None)
proj_initializer = tf.initializers.random_normal(
stddev=FLAGS.proj_init_std,
seed=None)
tie_projs = [False for _ in range(len(cutoffs) + 1)]
if FLAGS.proj_share_all_but_first:
for i in range(1, len(tie_projs)):
tie_projs[i] = True
loss, new_mems = model.transformer(
dec_inp=inp,
target=tgt,
mems=mems,
n_token=n_token,
n_layer=FLAGS.n_layer,
d_model=FLAGS.d_model,
d_embed=FLAGS.d_embed,
n_head=FLAGS.n_head,
d_head=FLAGS.d_head,
d_inner=FLAGS.d_inner,
dropout=FLAGS.dropout,
dropatt=FLAGS.dropatt,
initializer=initializer,
proj_initializer=proj_initializer,
is_training=is_training,
mem_len=FLAGS.mem_len,
cutoffs=cutoffs,
div_val=FLAGS.div_val,
tie_projs=tie_projs,
input_perms=None,
target_perms=None,
head_target=None,
same_length=FLAGS.same_length,
clamp_len=FLAGS.clamp_len,
use_tpu=False,
untie_r=FLAGS.untie_r,
proj_same_dim=FLAGS.proj_same_dim)
# number of parameters
num_params = sum([np.prod(v.shape) for v in tf.trainable_variables()])
tf.logging.info('#params: {}'.format(num_params))
# format_str = '{{:<{0}s}}\t{{}}'.format(
# max([len(v.name) for v in tf.trainable_variables()]))
# for v in tf.trainable_variables():
# tf.logging.info(format_str.format(v.name, v.get_shape()))
if is_training:
all_vars = tf.trainable_variables()
grads = tf.gradients(loss, all_vars)
grads_and_vars = list(zip(grads, all_vars))
return loss, new_mems, grads_and_vars
else:
return loss, new_mems
return model_fn
def single_core_graph(n_token, cutoffs, is_training, inp, tgt, mems):
model_fn = get_model_fn(
n_token=n_token,
cutoffs=cutoffs)
model_ret = model_fn(
inp=inp,
tgt=tgt,
mems=mems,
is_training=is_training)
return model_ret
def train(n_token, cutoffs, ps_device):
##### Get input function and model function
train_input_fn, train_record_info = data_utils.get_input_fn(
record_info_dir=FLAGS.record_info_dir,
split="train",
per_host_bsz=FLAGS.train_batch_size,
tgt_len=FLAGS.tgt_len,
num_core_per_host=FLAGS.num_core_per_host,
num_hosts=1,
use_tpu=False)
tf.logging.info("num of batches {}".format(train_record_info["num_batch"]))
##### Create computational graph
train_set = train_input_fn({
"batch_size": FLAGS.train_batch_size,
"data_dir": FLAGS.data_dir})
input_feed, label_feed = train_set.make_one_shot_iterator().get_next()
inputs = tf.split(input_feed, FLAGS.num_core_per_host, 0)
labels = tf.split(label_feed, FLAGS.num_core_per_host, 0)
per_core_bsz = FLAGS.train_batch_size // FLAGS.num_core_per_host
tower_mems, tower_losses, tower_new_mems, tower_grads_and_vars = [], [], [], []
for i in range(FLAGS.num_core_per_host):
reuse = True if i > 0 else None
with tf.device(assign_to_gpu(i, ps_device)), \
tf.variable_scope(tf.get_variable_scope(), reuse=reuse):
mems_i = [tf.placeholder(tf.float32,
[FLAGS.mem_len, per_core_bsz, FLAGS.d_model])
for _ in range(FLAGS.n_layer)]
loss_i, new_mems_i, grads_and_vars_i = single_core_graph(
n_token=n_token,
cutoffs=cutoffs,
is_training=True,
inp=inputs[i],
tgt=labels[i],
mems=mems_i)
tower_mems.append(mems_i)
tower_losses.append(loss_i)
tower_new_mems.append(new_mems_i)
tower_grads_and_vars.append(grads_and_vars_i)
## average losses and gradients across towers
if len(tower_losses) > 1:
loss = tf.add_n(tower_losses) / len(tower_losses)
grads_and_vars = average_grads_and_vars(tower_grads_and_vars)
else:
loss = tower_losses[0]
grads_and_vars = tower_grads_and_vars[0]
grads, all_vars = zip(*grads_and_vars)
## clip gradient
clipped, gnorm = tf.clip_by_global_norm(grads, FLAGS.clip)
grads_and_vars = list(zip(clipped, all_vars))
## configure the optimizer
global_step = tf.train.get_or_create_global_step()
# warmup stage: increase the learning rate linearly
if FLAGS.warmup_steps > 0:
warmup_lr = tf.to_float(global_step) / tf.to_float(FLAGS.warmup_steps) \
* FLAGS.learning_rate
else:
warmup_lr = 0.0
# decay stage: decay the learning rate using the cosine schedule
decay_lr = tf.train.cosine_decay(
FLAGS.learning_rate,
global_step=global_step-FLAGS.warmup_steps,
decay_steps=FLAGS.train_steps-FLAGS.warmup_steps,
alpha=FLAGS.min_lr_ratio)
# choose warmup or decay
learning_rate = tf.where(global_step < FLAGS.warmup_steps,
warmup_lr, decay_lr)
# get the train op
optimizer = tf.train.AdamOptimizer(learning_rate=learning_rate)
train_op = optimizer.apply_gradients(grads_and_vars, global_step)
##### Training loop
tower_mems_np = [
[np.zeros([FLAGS.mem_len, per_core_bsz, FLAGS.d_model], dtype=np.float32)
for layer in range(FLAGS.n_layer)]
for core in range(FLAGS.num_core_per_host)
]
saver = tf.train.Saver()
with tf.Session(config=tf.ConfigProto(allow_soft_placement=True)) as sess:
sess.run(tf.global_variables_initializer())
if FLAGS.warm_start_path is not None:
tf.logging.info("warm start from {}".format(FLAGS.warm_start_path))
saver.restore(sess, FLAGS.warm_start_path)
fetches = [loss, tower_new_mems, global_step, gnorm, learning_rate, train_op]
total_loss, prev_step = 0., -1
while True:
feed_dict = {}
for i in range(FLAGS.num_core_per_host):
for m, m_np in zip(tower_mems[i], tower_mems_np[i]):
feed_dict[m] = m_np
fetched = sess.run(fetches, feed_dict=feed_dict)
loss_np, tower_mems_np, curr_step = fetched[:3]
total_loss += loss_np
if curr_step > 0 and curr_step % FLAGS.iterations == 0:
curr_loss = total_loss / (curr_step - prev_step)
tf.logging.info("[{}] | gnorm {:.2f} lr {:8.6f} "
"| loss {:.2f} | pplx {:>7.2f}, bpc {:>7.4f}".format(
curr_step, fetched[-3], fetched[-2],
curr_loss, math.exp(curr_loss), curr_loss / math.log(2)))
total_loss, prev_step = 0., curr_step
if curr_step > 0 and curr_step % FLAGS.save_steps == 0:
save_path = os.path.join(FLAGS.model_dir, "model.ckpt")
saver.save(sess, save_path)
tf.logging.info("Model saved in path: {}".format(save_path))
if curr_step == FLAGS.train_steps:
break
def evaluate(n_token, cutoffs, ps_device):
##### Get input function and model function
eval_input_fn, eval_record_info = data_utils.get_input_fn(
record_info_dir=FLAGS.record_info_dir,
split=FLAGS.eval_split,
per_host_bsz=FLAGS.eval_batch_size,
tgt_len=FLAGS.tgt_len,
num_core_per_host=FLAGS.num_core_per_host,
num_hosts=1,
use_tpu=False)
num_batch = eval_record_info["num_batch"]
if FLAGS.max_eval_batch > 0:
num_batch = FLAGS.max_eval_batch
tf.logging.info("num of batches {}".format(num_batch))
##### Create computational graph
eval_set = eval_input_fn({
"batch_size": FLAGS.eval_batch_size,
"data_dir": FLAGS.data_dir})
input_feed, label_feed = eval_set.make_one_shot_iterator().get_next()
inputs = tf.split(input_feed, FLAGS.num_core_per_host, 0)
labels = tf.split(label_feed, FLAGS.num_core_per_host, 0)
per_core_bsz = FLAGS.eval_batch_size // FLAGS.num_core_per_host
tower_mems, tower_losses, tower_new_mems = [], [], []
for i in range(FLAGS.num_core_per_host):
with tf.device(assign_to_gpu(i, ps_device)), \
tf.variable_scope(tf.get_variable_scope(), reuse=tf.AUTO_REUSE):
mems_i = [tf.placeholder(tf.float32,
[FLAGS.mem_len, per_core_bsz, FLAGS.d_model])
for _ in range(FLAGS.n_layer)]
loss_i, new_mems_i = single_core_graph(
n_token=n_token,
cutoffs=cutoffs,
is_training=False,
inp=inputs[i],
tgt=labels[i],
mems=mems_i)
tower_mems.append(mems_i)
tower_losses.append(loss_i)
tower_new_mems.append(new_mems_i)
## sum losses across towers
if len(tower_losses) > 1:
loss = tf.add_n(tower_losses) / len(tower_losses)
else:
loss = tower_losses[0]
##### Evaluation loop
tower_mems_np = [
[np.zeros([FLAGS.mem_len, per_core_bsz, FLAGS.d_model], dtype=np.float32)
for layer in range(FLAGS.n_layer)]
for core in range(FLAGS.num_core_per_host)
]
saver = tf.train.Saver()
with tf.Session(config=tf.ConfigProto(allow_soft_placement=True)) as sess:
sess.run(tf.global_variables_initializer())
if FLAGS.eval_ckpt_path is None:
eval_ckpt_path = tf.train.latest_checkpoint(FLAGS.model_dir)
else:
eval_ckpt_path = FLAGS.eval_ckpt_path
tf.logging.info("Evaluate {}".format(eval_ckpt_path))
saver.restore(sess, eval_ckpt_path)
fetches = [loss, tower_new_mems, tf.size(label_feed)]
format_str = " >> processing batch {{:{0}d}}/{{:{0}d}} ..".format(
len(str(num_batch)))
total_loss, total_cnt = 0, 0
for step in range(num_batch):
if step % (num_batch // 10) == 0:
tf.logging.info(format_str.format(step, num_batch))
feed_dict = {}
for i in range(FLAGS.num_core_per_host):
for m, m_np in zip(tower_mems[i], tower_mems_np[i]):
feed_dict[m] = m_np
fetched = sess.run(fetches, feed_dict=feed_dict)
loss_np, tower_mems_np, cnt_np = fetched[:3]
total_loss += loss_np * cnt_np
total_cnt += cnt_np
avg_loss = total_loss / total_cnt
tf.logging.info("| loss {:.2f} | pplx {:>7.2f}, bpc {:>7.4f}".format(
avg_loss, math.exp(avg_loss), avg_loss / math.log(2)))
def main(unused_argv):
del unused_argv # Unused
tf.logging.set_verbosity(tf.logging.INFO)
# Get corpus info
corpus_info = data_utils.get_corpus_info(FLAGS.corpus_info_path)
n_token = corpus_info["vocab_size"]
cutoffs = corpus_info["cutoffs"][1:-1]
tf.logging.info("n_token {}".format(n_token))
if FLAGS.do_train:
train(n_token, cutoffs, "/gpu:0")
if FLAGS.do_eval:
evaluate(n_token, cutoffs, "/gpu:0")
if __name__ == "__main__":
tf.app.run()
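The only functional difference between train_gpu.py above and this older copy is the block marked "modified" in the training loop: the new version wraps sess.run in wall-clock timing and logs an instantaneous global_step/sec figure. A standalone sketch of that measurement pattern (function and variable names here are illustrative, not part of the repository):

import time

def timed_step(run_step):
    """Run one training step callable and return (result, steps_per_second)."""
    start = time.time()
    result = run_step()                  # e.g. sess.run(fetches, feed_dict=feed_dict)
    elapsed = time.time() - start
    steps_per_sec = (1.0 / elapsed) if elapsed > 0 else float("inf")
    return result, steps_per_sec

# usage sketch:
#   fetched, steps_per_sec = timed_step(lambda: sess.run(fetches, feed_dict=feed_dict))
#   tf.logging.info("global_step/sec : {}".format(steps_per_sec))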
TensorFlow/NLP/transformer-xl-master/vocabulary.py
0 → 100644
View file @
cb8dde1c
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from collections import Counter, OrderedDict

import numpy as np
import tensorflow as tf

from tensorflow.gfile import Open as open
from tensorflow.gfile import Exists as exists


class Vocab(object):
  def __init__(self, special=[], min_freq=0, max_size=None, lower_case=True,
               delimiter=None, vocab_file=None):
    self.counter = Counter()
    self.special = special
    self.min_freq = min_freq
    self.max_size = max_size
    self.lower_case = lower_case
    self.delimiter = delimiter
    self.vocab_file = vocab_file

  def tokenize(self, line, add_eos=False, add_double_eos=False):
    line = line.strip()
    # convert to lower case
    if self.lower_case:
      line = line.lower()

    # empty delimiter '' will evaluate False
    if self.delimiter == '':
      symbols = line
    else:
      symbols = line.split(self.delimiter)

    if add_double_eos:  # lm1b
      return ['<S>'] + symbols + ['<S>']
    elif add_eos:
      return symbols + ['<eos>']
    else:
      return symbols

  def count_file(self, path, verbose=False, add_eos=False):
    if verbose: print('counting file {} ...'.format(path))
    assert exists(path)

    sents = []
    with open(path, 'r') as f:
      for idx, line in enumerate(f):
        if verbose and idx > 0 and idx % 500000 == 0:
          print('    line {}'.format(idx))
        symbols = self.tokenize(line, add_eos=add_eos)
        self.counter.update(symbols)
        sents.append(symbols)

    return sents

  def count_sents(self, sents, verbose=False):
    """
      sents : a list of sentences, each a list of tokenized symbols
    """
    if verbose: print('counting {} sents ...'.format(len(sents)))
    for idx, symbols in enumerate(sents):
      if verbose and idx > 0 and idx % 500000 == 0:
        print('    line {}'.format(idx))
      self.counter.update(symbols)

  def _build_from_file(self, vocab_file):
    self.idx2sym = []
    self.sym2idx = OrderedDict()

    with open(vocab_file, 'r') as f:
      for line in f:
        symb = line.strip().split()[0]
        self.add_symbol(symb)
    self.unk_idx = self.sym2idx['<UNK>']

  def build_vocab(self):
    if self.vocab_file:
      print('building vocab from {}'.format(self.vocab_file))
      self._build_from_file(self.vocab_file)
      print('final vocab size {}'.format(len(self)))
    else:
      print('building vocab with min_freq={}, max_size={}'.format(
          self.min_freq, self.max_size))
      self.idx2sym = []
      self.sym2idx = OrderedDict()

      for sym in self.special:
        self.add_special(sym)

      for sym, cnt in self.counter.most_common(self.max_size):
        if cnt < self.min_freq: break
        self.add_symbol(sym)

      print('final vocab size {} from {} unique tokens'.format(
          len(self), len(self.counter)))

  def encode_file(self, path, ordered=False, verbose=False, add_eos=True,
                  add_double_eos=False):
    if verbose: print('encoding file {} ...'.format(path))
    assert exists(path)
    encoded = []
    with open(path, 'r') as f:
      for idx, line in enumerate(f):
        if verbose and idx > 0 and idx % 500000 == 0:
          print('    line {}'.format(idx))
        symbols = self.tokenize(line, add_eos=add_eos,
                                add_double_eos=add_double_eos)
        encoded.append(self.convert_to_nparray(symbols))

    if ordered:
      encoded = np.concatenate(encoded)

    return encoded

  def encode_sents(self, sents, ordered=False, verbose=False):
    if verbose: print('encoding {} sents ...'.format(len(sents)))
    encoded = []
    for idx, symbols in enumerate(sents):
      if verbose and idx > 0 and idx % 500000 == 0:
        print('    line {}'.format(idx))
      encoded.append(self.convert_to_nparray(symbols))

    if ordered:
      encoded = np.concatenate(encoded)

    return encoded

  def add_special(self, sym):
    if sym not in self.sym2idx:
      self.idx2sym.append(sym)
      self.sym2idx[sym] = len(self.idx2sym) - 1
      setattr(self, '{}_idx'.format(sym.strip('<>')), self.sym2idx[sym])

  def add_symbol(self, sym):
    if sym not in self.sym2idx:
      self.idx2sym.append(sym)
      self.sym2idx[sym] = len(self.idx2sym) - 1

  def get_sym(self, idx):
    assert 0 <= idx < len(self), 'Index {} out of range'.format(idx)
    return self.idx2sym[idx]

  def get_idx(self, sym):
    if sym in self.sym2idx:
      return self.sym2idx[sym]
    else:
      assert hasattr(self, 'unk_idx')
      return self.sym2idx.get(sym, self.unk_idx)

  def get_symbols(self, indices):
    return [self.get_sym(idx) for idx in indices]

  def get_indices(self, symbols):
    return [self.get_idx(sym) for sym in symbols]

  def convert_to_nparray(self, symbols):
    nparray = np.array(self.get_indices(symbols), dtype=np.int64)
    return nparray

  def convert_to_sent(self, indices, exclude=None):
    if exclude is None:
      return ' '.join([self.get_sym(idx) for idx in indices])
    else:
      return ' '.join([self.get_sym(idx) for idx in indices
                       if idx not in exclude])

  def __len__(self):
    return len(self.idx2sym)
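A hedged usage sketch of the Vocab class above: count symbols, freeze the vocabulary, then encode a corpus into one flat int64 array. The file path and special-symbol list are hypothetical; the count/build/encode order is the flow the class expects.

from vocabulary import Vocab

# hypothetical text file, one sentence per line
vocab = Vocab(special=['<eos>'], lower_case=True)
vocab.count_file('train.txt', add_eos=True)   # accumulate token counts
vocab.build_vocab()                           # freeze symbol <-> index tables

# ordered=True concatenates all lines into a single 1-D int64 array
train_ids = vocab.encode_file('train.txt', ordered=True, add_eos=True)
print(len(vocab), train_ids[:10])
print(vocab.convert_to_sent(train_ids[:10]))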