Commit c705568b
Authored Aug 29, 2017 by Yanping Huang; committed by Neal Wu, Aug 29, 2017

Add different rnn implementation modes to ptb tutorial (#2276)

parent 7e9e15ad
Showing 4 changed files with 302 additions and 67 deletions (+302 −67)
tutorials/rnn/ptb/BUILD           +9   −1
tutorials/rnn/ptb/__init__.py     +1   −0
tutorials/rnn/ptb/ptb_word_lm.py  +196 −66
tutorials/rnn/ptb/util.py         +96  −0
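The headline change is a new --rnn_mode flag that selects the LSTM implementation at run time, plus a --num_gpus flag for Grappler AutoParallel replication. A hedged example invocation, where the data path is a placeholder following the tutorial's usual PTB layout:

    python ptb_word_lm.py --data_path=/tmp/simple-examples/data --model=small --rnn_mode=block

The mode strings are the lowercase constants defined in the diff: "basic", "cudnn", or "block".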
tutorials/rnn/ptb/BUILD

...
@@ -36,6 +36,13 @@ py_test(
     ],
 )
 
+py_library(
+    name = "util",
+    srcs = ["util.py"],
+    srcs_version = "PY2AND3",
+    deps = ["//tensorflow:tensorflow_py"],
+)
+
 py_binary(
     name = "ptb_word_lm",
     srcs = [
...
@@ -44,7 +51,8 @@ py_binary(
     srcs_version = "PY2AND3",
     deps = [
         ":reader",
+        ":util",
         "//tensorflow:tensorflow_py",
     ],
 )
...
tutorials/rnn/ptb/__init__.py

...
@@ -19,3 +19,4 @@ from __future__ import division
 from __future__ import print_function
 
 import reader
+import util
tutorials/rnn/ptb/ptb_word_lm.py

...
@@ -40,6 +40,9 @@ The hyperparameters used in the model:
 - keep_prob - the probability of keeping weights in the dropout layer
 - lr_decay - the decay of the learning rate for each epoch after "max_epoch"
 - batch_size - the batch size
+- rnn_mode - the low level implementation of lstm cell: one of CUDNN,
+             BASIC, or BLOCK, representing cudnn_lstm, basic_lstm, and
+             lstm_block_cell classes.
 
 The data required for this example is in the data/ dir of the
 PTB dataset from Tomas Mikolov's webpage:
...
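For orientation while reading the diff, the three mode strings resolve to concrete cell classes defined further down in this file (_get_lstm_cell and _build_rnn_graph_cudnn). A minimal summary sketch, assuming the TF 1.3-era contrib API used throughout this commit:

    import tensorflow as tf
    
    # Mode string -> implementing class, per this commit:
    RNN_MODE_TO_CLASS = {
        "basic": tf.contrib.rnn.BasicLSTMCell,    # pure-Python reference cell
        "block": tf.contrib.rnn.LSTMBlockCell,    # fused kernel, same math
        "cudnn": tf.contrib.cudnn_rnn.CudnnLSTM,  # cuDNN-backed, GPU only
    }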
@@ -56,13 +59,15 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
 
-import inspect
 import time
 
 import numpy as np
 import tensorflow as tf
 
 import reader
+import util
+
+from tensorflow.python.client import device_lib
 
 flags = tf.flags
 logging = tf.logging
...
@@ -76,8 +81,18 @@ flags.DEFINE_string("save_path", None,
                     "Model output directory.")
 flags.DEFINE_bool("use_fp16", False,
                   "Train using 16-bit floats instead of 32bit floats")
+flags.DEFINE_integer("num_gpus", 1,
+                     "If larger than 1, Grappler AutoParallel optimizer "
+                     "will create multiple training replicas with each GPU "
+                     "running one replica.")
+flags.DEFINE_string("rnn_mode", None,
+                    "The low level implementation of lstm cell: one of CUDNN, "
+                    "BASIC, and BLOCK, representing cudnn_lstm, basic_lstm, "
+                    "and lstm_block_cell classes.")
 
 FLAGS = flags.FLAGS
+BASIC = "basic"
+CUDNN = "cudnn"
+BLOCK = "block"
 
 
 def data_type():
...
@@ -99,39 +114,15 @@ class PTBModel(object):
   """The PTB model."""
 
   def __init__(self, is_training, config, input_):
+    self._is_training = is_training
     self._input = input_
-
-    batch_size = input_.batch_size
-    num_steps = input_.num_steps
+    self._rnn_params = None
+    self._cell = None
+    self.batch_size = input_.batch_size
+    self.num_steps = input_.num_steps
     size = config.hidden_size
     vocab_size = config.vocab_size
 
-    # Slightly better results can be obtained with forget gate biases
-    # initialized to 1 but the hyperparameters of the model would need to be
-    # different than reported in the paper.
-    def lstm_cell():
-      # With the latest TensorFlow source code (as of Mar 27, 2017),
-      # the BasicLSTMCell will need a reuse parameter which is unfortunately not
-      # defined in TensorFlow 1.0. To maintain backwards compatibility, we add
-      # an argument check here:
-      if 'reuse' in inspect.getargspec(
-          tf.contrib.rnn.BasicLSTMCell.__init__).args:
-        return tf.contrib.rnn.BasicLSTMCell(
-            size, forget_bias=0.0, state_is_tuple=True,
-            reuse=tf.get_variable_scope().reuse)
-      else:
-        return tf.contrib.rnn.BasicLSTMCell(
-            size, forget_bias=0.0, state_is_tuple=True)
-    attn_cell = lstm_cell
-    if is_training and config.keep_prob < 1:
-      def attn_cell():
-        return tf.contrib.rnn.DropoutWrapper(
-            lstm_cell(), output_keep_prob=config.keep_prob)
-    cell = tf.contrib.rnn.MultiRNNCell(
-        [attn_cell() for _ in range(config.num_layers)], state_is_tuple=True)
-
-    self._initial_state = cell.zero_state(batch_size, data_type())
-
     with tf.device("/cpu:0"):
       embedding = tf.get_variable(
           "embedding", [vocab_size, size], dtype=data_type())
...
@@ -140,43 +131,25 @@ class PTBModel(object):
     if is_training and config.keep_prob < 1:
       inputs = tf.nn.dropout(inputs, config.keep_prob)
 
-    # Simplified version of models/tutorials/rnn/rnn.py's rnn().
-    # This builds an unrolled LSTM for tutorial purposes only.
-    # In general, use the rnn() or state_saving_rnn() from rnn.py.
-    #
-    # The alternative version of the code below is:
-    #
-    # inputs = tf.unstack(inputs, num=num_steps, axis=1)
-    # outputs, state = tf.contrib.rnn.static_rnn(
-    #     cell, inputs, initial_state=self._initial_state)
-    outputs = []
-    state = self._initial_state
-    with tf.variable_scope("RNN"):
-      for time_step in range(num_steps):
-        if time_step > 0: tf.get_variable_scope().reuse_variables()
-        (cell_output, state) = cell(inputs[:, time_step, :], state)
-        outputs.append(cell_output)
+    output, state = self._build_rnn_graph(inputs, config, is_training)
 
-    output = tf.reshape(tf.stack(axis=1, values=outputs), [-1, size])
     softmax_w = tf.get_variable(
         "softmax_w", [size, vocab_size], dtype=data_type())
     softmax_b = tf.get_variable("softmax_b", [vocab_size], dtype=data_type())
-    logits = tf.matmul(output, softmax_w) + softmax_b
-    # Reshape logits to be 3-D tensor for sequence loss
-    logits = tf.reshape(logits, [batch_size, num_steps, vocab_size])
+    logits = tf.nn.xw_plus_b(output, softmax_w, softmax_b)
+    # Reshape logits to be a 3-D tensor for sequence loss
+    logits = tf.reshape(logits, [self.batch_size, self.num_steps, vocab_size])
 
-    # use the contrib sequence loss and average over the batches
+    # Use the contrib sequence loss and average over the batches
     loss = tf.contrib.seq2seq.sequence_loss(
         logits,
         input_.targets,
-        tf.ones([batch_size, num_steps], dtype=data_type()),
+        tf.ones([self.batch_size, self.num_steps], dtype=data_type()),
         average_across_timesteps=False,
         average_across_batch=True)
 
-    # update the cost variables
-    self._cost = cost = tf.reduce_sum(loss)
+    # Update the cost
+    self._cost = tf.reduce_sum(loss)
     self._final_state = state
 
     if not is_training:
...
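tf.contrib.seq2seq.sequence_loss as called above takes [batch, time, vocab] logits, [batch, time] integer targets, and a [batch, time] weight matrix; with average_across_timesteps=False and average_across_batch=True it returns one loss value per time step, which is why the model then applies tf.reduce_sum. A minimal standalone sketch with toy shapes (TF 1.x graph mode; the sizes are illustrative only):

    import numpy as np
    import tensorflow as tf
    
    batch_size, num_steps, vocab_size = 2, 3, 5
    logits = tf.constant(
        np.random.randn(batch_size, num_steps, vocab_size), dtype=tf.float32)
    targets = tf.constant(
        np.random.randint(vocab_size, size=(batch_size, num_steps)), dtype=tf.int32)
    # Uniform weights; keep per-timestep losses, average across the batch,
    # matching the flags used in the diff above.
    loss = tf.contrib.seq2seq.sequence_loss(
        logits, targets, tf.ones([batch_size, num_steps]),
        average_across_timesteps=False, average_across_batch=True)
    cost = tf.reduce_sum(loss)  # scalar, as in PTBModel
    with tf.Session() as sess:
      print(sess.run(loss).shape)  # (3,) -- one value per time step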
@@ -184,7 +157,7 @@ class PTBModel(object):
     self._lr = tf.Variable(0.0, trainable=False)
     tvars = tf.trainable_variables()
-    grads, _ = tf.clip_by_global_norm(tf.gradients(cost, tvars),
+    grads, _ = tf.clip_by_global_norm(tf.gradients(self._cost, tvars),
                                       config.max_grad_norm)
     optimizer = tf.train.GradientDescentOptimizer(self._lr)
     self._train_op = optimizer.apply_gradients(
...
@@ -195,9 +168,120 @@ class PTBModel(object):
         tf.float32, shape=[], name="new_learning_rate")
     self._lr_update = tf.assign(self._lr, self._new_lr)
 
+  def _build_rnn_graph(self, inputs, config, is_training):
+    if config.rnn_mode == CUDNN:
+      return self._build_rnn_graph_cudnn(inputs, config, is_training)
+    else:
+      return self._build_rnn_graph_lstm(inputs, config, is_training)
+
+  def _build_rnn_graph_cudnn(self, inputs, config, is_training):
+    """Build the inference graph using CUDNN cell."""
+    inputs = tf.transpose(inputs, [1, 0, 2])
+    self._cell = tf.contrib.cudnn_rnn.CudnnLSTM(
+        num_layers=config.num_layers,
+        num_units=config.hidden_size,
+        input_size=config.hidden_size,
+        dropout=1 - config.keep_prob if is_training else 0)
+    params_size_t = self._cell.params_size()
+    self._rnn_params = tf.get_variable(
+        "lstm_params",
+        initializer=tf.random_uniform(
+            [params_size_t], -config.init_scale, config.init_scale),
+        validate_shape=False)
+    c = tf.zeros([config.num_layers, self.batch_size, config.hidden_size],
+                 tf.float32)
+    h = tf.zeros([config.num_layers, self.batch_size, config.hidden_size],
+                 tf.float32)
+    self._initial_state = (tf.contrib.rnn.LSTMStateTuple(h=h, c=c),)
+    outputs, h, c = self._cell(inputs, h, c, self._rnn_params, is_training)
+    outputs = tf.transpose(outputs, [1, 0, 2])
+    outputs = tf.reshape(outputs, [-1, config.hidden_size])
+    return outputs, (tf.contrib.rnn.LSTMStateTuple(h=h, c=c),)
+
+  def _get_lstm_cell(self, config, is_training):
+    if config.rnn_mode == BASIC:
+      return tf.contrib.rnn.BasicLSTMCell(
+          config.hidden_size, forget_bias=0.0, state_is_tuple=True,
+          reuse=not is_training)
+    if config.rnn_mode == BLOCK:
+      return tf.contrib.rnn.LSTMBlockCell(
+          config.hidden_size, forget_bias=0.0)
+    raise ValueError("rnn_mode %s not supported" % config.rnn_mode)
+
+  def _build_rnn_graph_lstm(self, inputs, config, is_training):
+    """Build the inference graph using canonical LSTM cells."""
+    # Slightly better results can be obtained with forget gate biases
+    # initialized to 1 but the hyperparameters of the model would need to be
+    # different than reported in the paper.
+    cell = self._get_lstm_cell(config, is_training)
+    if is_training and config.keep_prob < 1:
+      cell = tf.contrib.rnn.DropoutWrapper(
+          cell, output_keep_prob=config.keep_prob)
+
+    cell = tf.contrib.rnn.MultiRNNCell(
+        [cell for _ in range(config.num_layers)], state_is_tuple=True)
+
+    self._initial_state = cell.zero_state(config.batch_size, data_type())
+    state = self._initial_state
+    # Simplified version of tensorflow_models/tutorials/rnn/rnn.py's rnn().
+    # This builds an unrolled LSTM for tutorial purposes only.
+    # In general, use the rnn() or state_saving_rnn() from rnn.py.
+    #
+    # The alternative version of the code below is:
+    #
+    # inputs = tf.unstack(inputs, num=num_steps, axis=1)
+    # outputs, state = tf.contrib.rnn.static_rnn(cell, inputs,
+    #                                            initial_state=self._initial_state)
+    outputs = []
+    with tf.variable_scope("RNN"):
+      for time_step in range(self.num_steps):
+        if time_step > 0: tf.get_variable_scope().reuse_variables()
+        (cell_output, state) = cell(inputs[:, time_step, :], state)
+        outputs.append(cell_output)
+    output = tf.reshape(tf.concat(outputs, 1), [-1, config.hidden_size])
+    return output, state
+
   def assign_lr(self, session, lr_value):
     session.run(self._lr_update, feed_dict={self._new_lr: lr_value})
 
+  def export_ops(self, name):
+    """Exports ops to collections."""
+    self._name = name
+    ops = {util.with_prefix(self._name, "cost"): self._cost}
+    if self._is_training:
+      ops.update(lr=self._lr, new_lr=self._new_lr, lr_update=self._lr_update)
+      if self._rnn_params:
+        ops.update(rnn_params=self._rnn_params)
+    for name, op in ops.iteritems():
+      tf.add_to_collection(name, op)
+    self._initial_state_name = util.with_prefix(self._name, "initial")
+    self._final_state_name = util.with_prefix(self._name, "final")
+    util.export_state_tuples(self._initial_state, self._initial_state_name)
+    util.export_state_tuples(self._final_state, self._final_state_name)
+
+  def import_ops(self):
+    """Imports ops from collections."""
+    if self._is_training:
+      self._train_op = tf.get_collection_ref("train_op")[0]
+      self._lr = tf.get_collection_ref("lr")[0]
+      self._new_lr = tf.get_collection_ref("new_lr")[0]
+      self._lr_update = tf.get_collection_ref("lr_update")[0]
+      rnn_params = tf.get_collection_ref("rnn_params")
+      if self._cell and rnn_params:
+        params_saveable = tf.contrib.cudnn_rnn.RNNParamsSaveable(
+            self._cell,
+            self._cell.params_to_canonical,
+            self._cell.canonical_to_params,
+            rnn_params,
+            base_variable_scope="Model/RNN")
+        tf.add_to_collection(tf.GraphKeys.SAVEABLE_OBJECTS, params_saveable)
+    self._cost = tf.get_collection_ref(util.with_prefix(self._name, "cost"))[0]
+    num_replicas = FLAGS.num_gpus if self._name == "Train" else 1
+    self._initial_state = util.import_state_tuples(
+        self._initial_state, self._initial_state_name, num_replicas)
+    self._final_state = util.import_state_tuples(
+        self._final_state, self._final_state_name, num_replicas)
+
   @property
   def input(self):
     return self._input
...
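The CUDNN path above departs from the cell/unrolled-loop pattern in three ways: the input must be time-major, all layer weights live in one flat parameter buffer whose size is itself a graph tensor (hence validate_shape=False), and the object is invoked directly with explicit h and c states. A stripped-down sketch of the same TF 1.3 contrib API outside the model class (shapes illustrative; actually running the op requires a CUDA GPU):

    import tensorflow as tf
    
    num_layers, hidden_size, batch_size, num_steps = 2, 200, 20, 35
    cell = tf.contrib.cudnn_rnn.CudnnLSTM(
        num_layers=num_layers, num_units=hidden_size, input_size=hidden_size)
    # Flat, opaque weight/bias buffer; its size is only known as a graph tensor.
    params = tf.get_variable(
        "lstm_params",
        initializer=tf.random_uniform([cell.params_size()], -0.1, 0.1),
        validate_shape=False)
    inputs = tf.zeros([num_steps, batch_size, hidden_size])  # time-major
    h = tf.zeros([num_layers, batch_size, hidden_size])
    c = tf.zeros([num_layers, batch_size, hidden_size])
    outputs, h_out, c_out = cell(inputs, h, c, params, is_training=True)

One portability note: export_ops iterates ops.iteritems(), which exists only in Python 2; despite the PY2AND3 tag in the BUILD file, running under Python 3 would presumably require .items() here and in main().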
@@ -222,6 +306,14 @@ class PTBModel(object):
   def train_op(self):
     return self._train_op
 
+  @property
+  def initial_state_name(self):
+    return self._initial_state_name
+
+  @property
+  def final_state_name(self):
+    return self._final_state_name
+
 
 class SmallConfig(object):
   """Small config."""
...
@@ -237,6 +329,7 @@ class SmallConfig(object):
   lr_decay = 0.5
   batch_size = 20
   vocab_size = 10000
+  rnn_mode = CUDNN
 
 
 class MediumConfig(object):
...
@@ -253,6 +346,7 @@ class MediumConfig(object):
   lr_decay = 0.8
   batch_size = 20
   vocab_size = 10000
+  rnn_mode = BLOCK
 
 
 class LargeConfig(object):
...
@@ -269,6 +363,7 @@ class LargeConfig(object):
   lr_decay = 1 / 1.15
   batch_size = 20
   vocab_size = 10000
+  rnn_mode = BLOCK
 
 
 class TestConfig(object):
...
@@ -285,6 +380,7 @@ class TestConfig(object):
   lr_decay = 0.5
   batch_size = 20
   vocab_size = 10000
+  rnn_mode = BLOCK
 
 
 def run_epoch(session, model, eval_op=None, verbose=False):
...
@@ -317,27 +413,43 @@ def run_epoch(session, model, eval_op=None, verbose=False):
     if verbose and step % (model.input.epoch_size // 10) == 10:
       print("%.3f perplexity: %.3f speed: %.0f wps" %
             (step * 1.0 / model.input.epoch_size, np.exp(costs / iters),
-             iters * model.input.batch_size / (time.time() - start_time)))
+             iters * model.input.batch_size * max(1, FLAGS.num_gpus) /
+             (time.time() - start_time)))
 
   return np.exp(costs / iters)
 
 
 def get_config():
+  """Get model config."""
+  config = None
   if FLAGS.model == "small":
-    return SmallConfig()
+    config = SmallConfig()
   elif FLAGS.model == "medium":
-    return MediumConfig()
+    config = MediumConfig()
   elif FLAGS.model == "large":
-    return LargeConfig()
+    config = LargeConfig()
   elif FLAGS.model == "test":
-    return TestConfig()
+    config = TestConfig()
   else:
     raise ValueError("Invalid model: %s", FLAGS.model)
+  if FLAGS.rnn_mode:
+    config.rnn_mode = FLAGS.rnn_mode
+  if FLAGS.num_gpus != 1 or tf.__version__ < "1.3.0":
+    config.rnn_mode = BASIC
+  return config
 
 
 def main(_):
   if not FLAGS.data_path:
     raise ValueError("Must set --data_path to PTB data directory")
+  gpus = [
+      x.name for x in device_lib.list_local_devices() if x.device_type == "GPU"
+  ]
+  if FLAGS.num_gpus > len(gpus):
+    raise ValueError(
+        "Your machine has only %d gpus "
+        "which is less than the requested --num_gpus=%d."
+        % (len(gpus), FLAGS.num_gpus))
 
   raw_data = reader.ptb_raw_data(FLAGS.data_path)
   train_data, valid_data, test_data, _ = raw_data
...
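One caveat in the new get_config() fallback: tf.__version__ < "1.3.0" is a lexicographic string comparison. It happens to behave for the 1.0–1.9 releases, but version strings do not sort numerically in general:

    print("1.10.0" < "1.3.0")      # True  -- "1" sorts before "3" character-wise
    print((1, 10, 0) < (1, 3, 0))  # False -- parsed tuples compare numerically

Comparing parsed tuples (or using distutils.version / pkg_resources) would be the robust alternative.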
@@ -365,13 +477,31 @@ def main(_):
       tf.summary.scalar("Validation Loss", mvalid.cost)
 
     with tf.name_scope("Test"):
-      test_input = PTBInput(config=eval_config, data=test_data, name="TestInput")
+      test_input = PTBInput(
+          config=eval_config, data=test_data, name="TestInput")
       with tf.variable_scope("Model", reuse=True, initializer=initializer):
         mtest = PTBModel(is_training=False, config=eval_config,
                          input_=test_input)
 
+    models = {"Train": m, "Valid": mvalid, "Test": mtest}
+    for name, model in models.iteritems():
+      model.export_ops(name)
+    metagraph = tf.train.export_meta_graph()
+    if tf.__version__ < "1.1.0" and FLAGS.num_gpus > 1:
+      raise ValueError("num_gpus > 1 is not supported for TensorFlow versions "
+                       "below 1.1.0")
+    soft_placement = False
+    if FLAGS.num_gpus > 1:
+      soft_placement = True
+      util.auto_parallel(metagraph, m)
+
+  with tf.Graph().as_default():
+    tf.train.import_meta_graph(metagraph)
+    for model in models.values():
+      model.import_ops()
     sv = tf.train.Supervisor(logdir=FLAGS.save_path)
-    with sv.managed_session() as session:
+    config_proto = tf.ConfigProto(allow_soft_placement=soft_placement)
+    with sv.managed_session(config=config_proto) as session:
       for i in range(config.max_max_epoch):
         lr_decay = config.lr_decay ** max(i + 1 - config.max_epoch, 0.0)
         m.assign_lr(session, config.learning_rate * lr_decay)
...
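The new control flow in main() is: build all three models in one graph, stash their tensors into named collections (export_ops), serialize everything with tf.train.export_meta_graph, optionally let Grappler AutoParallel rewrite the MetaGraphDef, then re-import it into a fresh graph and re-bind the Python model objects from the collections (import_ops). A minimal sketch of just that round-trip, with a toy op standing in for the model (TF 1.x graph mode):

    import tensorflow as tf
    
    with tf.Graph().as_default():
      v = tf.get_variable("v", initializer=1.0)
      tf.add_to_collection("Train/cost", v * 2.0)    # export_ops-style naming
      metagraph = tf.train.export_meta_graph()       # graph_def + collections
    
    with tf.Graph().as_default():
      tf.train.import_meta_graph(metagraph)          # rebuilds graph + collections
      cost = tf.get_collection_ref("Train/cost")[0]  # import_ops-style lookup
      with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        print(sess.run(cost))  # 2.0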
tutorials/rnn/ptb/util.py (new file, mode 100644)
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Utilities for Grappler autoparallel optimizer."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import tensorflow as tf

from tensorflow.core.framework import variable_pb2
from tensorflow.core.protobuf import rewriter_config_pb2

FLAGS = tf.flags.FLAGS


def export_state_tuples(state_tuples, name):
  for state_tuple in state_tuples:
    tf.add_to_collection(name, state_tuple.c)
    tf.add_to_collection(name, state_tuple.h)


def import_state_tuples(state_tuples, name, num_replicas):
  restored = []
  for i in range(len(state_tuples) * num_replicas):
    c = tf.get_collection_ref(name)[2 * i + 0]
    h = tf.get_collection_ref(name)[2 * i + 1]
    restored.append(tf.contrib.rnn.LSTMStateTuple(c, h))
  return tuple(restored)


def with_prefix(prefix, name):
  """Adds prefix to name."""
  return "/".join((prefix, name))


def with_autoparallel_prefix(replica_id, name):
  return with_prefix("AutoParallel-Replica-%d" % replica_id, name)


class UpdateCollection(object):
  """Update collection info in MetaGraphDef for AutoParallel optimizer."""

  def __init__(self, metagraph, model):
    self._metagraph = metagraph
    self.replicate_states(model.initial_state_name)
    self.replicate_states(model.final_state_name)
    self.update_snapshot_name("variables")
    self.update_snapshot_name("trainable_variables")

  def update_snapshot_name(self, var_coll_name):
    var_list = self._metagraph.collection_def[var_coll_name]
    for i, value in enumerate(var_list.bytes_list.value):
      var_def = variable_pb2.VariableDef()
      var_def.ParseFromString(value)
      # Somehow node Model/global_step/read doesn't have any fanout and seems to
      # be only used for snapshot; this is different from all other variables.
      if var_def.snapshot_name != "Model/global_step/read:0":
        var_def.snapshot_name = with_autoparallel_prefix(
            0, var_def.snapshot_name)
      value = var_def.SerializeToString()
      var_list.bytes_list.value[i] = value

  def replicate_states(self, state_coll_name):
    state_list = self._metagraph.collection_def[state_coll_name]
    num_states = len(state_list.node_list.value)
    for replica_id in range(1, FLAGS.num_gpus):
      for i in range(num_states):
        state_list.node_list.value.append(state_list.node_list.value[i])
    for replica_id in range(FLAGS.num_gpus):
      for i in range(num_states):
        index = replica_id * num_states + i
        state_list.node_list.value[index] = with_autoparallel_prefix(
            replica_id, state_list.node_list.value[index])


def auto_parallel(metagraph, model):
  from google3.third_party.tensorflow.python.grappler import tf_optimizer
  rewriter_config = rewriter_config_pb2.RewriterConfig()
  rewriter_config.optimizers.append("autoparallel")
  rewriter_config.auto_parallel.enable = True
  rewriter_config.auto_parallel.num_replicas = FLAGS.num_gpus
  optimized_graph = tf_optimizer.OptimizeGraph(rewriter_config, metagraph)
  metagraph.graph_def.CopyFrom(optimized_graph)
  UpdateCollection(metagraph, model)
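The import inside auto_parallel() references google3.third_party...., a Google-internal path that will not resolve in an open-source checkout; the public module is presumably tensorflow.python.grappler.tf_optimizer. A hedged sketch of the equivalent call (auto_parallel_oss is a hypothetical name for illustration):

    # Assumed open-source equivalent of the internal import in auto_parallel():
    from tensorflow.core.protobuf import rewriter_config_pb2
    from tensorflow.python.grappler import tf_optimizer
    
    def auto_parallel_oss(metagraph, num_replicas):
      """Rewrites `metagraph` in place with AutoParallel replicas."""
      rewriter_config = rewriter_config_pb2.RewriterConfig()
      rewriter_config.optimizers.append("autoparallel")
      rewriter_config.auto_parallel.enable = True
      rewriter_config.auto_parallel.num_replicas = num_replicas
      optimized_graph = tf_optimizer.OptimizeGraph(rewriter_config, metagraph)
      metagraph.graph_def.CopyFrom(optimized_graph)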