Commit af6527c9 (unverified)
Authored Feb 19, 2018 by Andrew M Dai; committed by GitHub on Feb 19, 2018

Merge pull request #3402 from a-dai/master

Merging in improvements and fixes to adversarial_text

Parents: f51da4bb, 9a9e4228
Showing 14 changed files with 94 additions and 68 deletions (+94 −68)
CODEOWNERS (+1 −1)
research/adversarial_text/BUILD (+6 −3)
research/adversarial_text/README.md (+1 −0)
research/adversarial_text/adversarial_losses.py (+18 −9)
research/adversarial_text/data/data_utils.py (+1 −1)
research/adversarial_text/data/document_generators.py (+1 −1)
research/adversarial_text/evaluate.py (+4 −2)
research/adversarial_text/graphs.py (+33 −10)
research/adversarial_text/graphs_test.py (+1 −1)
research/adversarial_text/inputs.py (+0 −11)
research/adversarial_text/layers.py (+21 −21)
research/adversarial_text/pretrain.py (+2 −2)
research/adversarial_text/train_classifier.py (+2 −2)
research/adversarial_text/train_utils.py (+3 −4)
CODEOWNERS

 /official/ @nealwu @k-w-w @karmel
 /research/adversarial_crypto/ @dave-andersen
-/research/adversarial_text/ @rsepassi
+/research/adversarial_text/ @rsepassi @a-dai
 /research/adv_imagenet_models/ @AlexeyKurakin
 /research/attention_ocr/ @alexgorban
 /research/audioset/ @plakal @dpwe
 ...
research/adversarial_text/BUILD

 licenses(["notice"])  # Apache 2.0

+exports_files(["LICENSE"])
+
 # Binaries
 # ==============================================================================
 py_binary(
 ...

@@ -8,7 +10,7 @@ py_binary(
     deps = [
         ":graphs",
         # google3 file dep,
-        # tensorflow dep,
+        # tensorflow internal dep,
     ],
 )

@@ -19,7 +21,7 @@ py_binary(
         ":graphs",
         ":train_utils",
         # google3 file dep,
-        # tensorflow dep,
+        # tensorflow internal dep,
     ],
 )

@@ -32,7 +34,8 @@ py_binary(
         ":graphs",
         ":train_utils",
         # google3 file dep,
-        # tensorflow dep,
+        # tensorflow internal gpu deps
+        # tensorflow internal dep,
     ],
 )
research/adversarial_text/README.md

@@ -154,3 +154,4 @@ control which dataset is processed and how.
 ## Contact for Issues
 
 * Ryan Sepassi, @rsepassi
+* Andrew M. Dai, @a-dai
research/adversarial_text/adversarial_losses.py

@@ -16,7 +16,6 @@
 from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
-from six.moves import xrange  # pylint: disable=redefined-builtin
 
 # Dependency imports
@@ -39,6 +38,8 @@ flags.DEFINE_float('small_constant_for_finite_diff', 1e-1,
 # Parameters for building the graph
 flags.DEFINE_string(
     'adv_training_method', None,
     'The flag which specifies training method. '
+    '"" : non-adversarial training (e.g. for running the '
+    ' semi-supervised sequence learning model) '
     '"rp" : random perturbation training '
     '"at" : adversarial training '
     '"vat" : virtual adversarial training '
@@ -74,7 +75,7 @@ def virtual_adversarial_loss(logits, embedded, inputs,
   between the new logits and the original logits.
 
   Args:
-    logits: 2-D float Tensor, [num_timesteps*batch_size, m], where m=1 if
+    logits: 3-D float Tensor, [batch_size, num_timesteps, m], where m=1 if
       num_classes=2, otherwise m=num_classes.
     embedded: 3-D float Tensor, [batch_size, num_timesteps, embedding_dim].
     inputs: VatxtInput.
@@ -90,6 +91,9 @@ def virtual_adversarial_loss(logits, embedded, inputs,
   # Only care about the KL divergence on the final timestep.
   weights = inputs.eos_weights
   assert weights is not None
+  if FLAGS.single_label:
+    indices = tf.stack([tf.range(FLAGS.batch_size), inputs.length - 1], 1)
+    weights = tf.expand_dims(tf.gather_nd(inputs.eos_weights, indices), 1)
 
   # Initialize perturbation with random noise.
   # shape(embedded) = (batch_size, num_timesteps, embedding_dim)
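Every `single_label` branch in this commit uses the same indexing pattern as the hunk above: pair each batch row with its final valid timestep and let `tf.gather_nd` select one value per example. A minimal standalone sketch of that pattern, with invented shapes and values (not part of the commit):

import tensorflow as tf

batch_size, num_timesteps = 3, 5
# Stand-in for eos_weights, shape [batch_size, num_timesteps].
weights = tf.reshape(
    tf.range(batch_size * num_timesteps, dtype=tf.float32),
    [batch_size, num_timesteps])
length = tf.constant([5, 2, 4])  # true length of each sequence

# One (row, column) pair per example: row i, column length[i] - 1.
indices = tf.stack([tf.range(batch_size), length - 1], 1)
# gather_nd yields shape [batch_size]; expand_dims restores a time axis of 1.
final_weights = tf.expand_dims(tf.gather_nd(weights, indices), 1)

with tf.Session() as sess:
  print(sess.run(final_weights))  # [[4.], [6.], [13.]]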
@@ -102,6 +106,7 @@ def virtual_adversarial_loss(logits, embedded, inputs,
   for _ in xrange(FLAGS.num_power_iteration):
     d = _scale_l2(
         _mask_by_length(d, inputs.length), FLAGS.small_constant_for_finite_diff)
     d_logits = logits_from_embedding_fn(embedded + d)
+    kl = _kl_divergence_with_logits(logits, d_logits, weights)
     d, = tf.gradients(
@@ -142,6 +147,9 @@ def virtual_adversarial_loss_bidir(logits, embedded, inputs,
   logits = tf.stop_gradient(logits)
   f_inputs, _ = inputs
   weights = f_inputs.eos_weights
+  if FLAGS.single_label:
+    indices = tf.stack([tf.range(FLAGS.batch_size), f_inputs.length - 1], 1)
+    weights = tf.expand_dims(tf.gather_nd(f_inputs.eos_weights, indices), 1)
   assert weights is not None
 
   perturbs = [
@@ -195,10 +203,10 @@ def _kl_divergence_with_logits(q_logits, p_logits, weights):
   Args:
-    q_logits: logits for 1st argument of KL divergence shape
-              [num_timesteps * batch_size, num_classes] if num_classes > 2, and
-              [num_timesteps * batch_size] if num_classes == 2.
+    q_logits: logits for 1st argument of KL divergence shape
+              [batch_size, num_timesteps, num_classes] if num_classes > 2, and
+              [batch_size, num_timesteps] if num_classes == 2.
     p_logits: logits for 2nd argument of KL divergence with same shape q_logits.
-    weights: 1-D float tensor with shape [num_timesteps * batch_size].
+    weights: 1-D float tensor with shape [batch_size, num_timesteps].
       Elements should be 1.0 only on end of sequences
 
   Returns:
@@ -209,18 +217,19 @@ def _kl_divergence_with_logits(q_logits, p_logits, weights):
     q = tf.nn.sigmoid(q_logits)
     kl = (-tf.nn.sigmoid_cross_entropy_with_logits(logits=q_logits, labels=q) +
           tf.nn.sigmoid_cross_entropy_with_logits(logits=p_logits, labels=q))
-    kl = tf.squeeze(kl)
+    kl = tf.squeeze(kl, 2)
 
   # For softmax regression
   else:
     q = tf.nn.softmax(q_logits)
-    kl = tf.reduce_sum(
-        q * (tf.nn.log_softmax(q_logits) - tf.nn.log_softmax(p_logits)), 1)
+    kl = tf.reduce_sum(
+        q * (tf.nn.log_softmax(q_logits) - tf.nn.log_softmax(p_logits)), -1)
 
   num_labels = tf.reduce_sum(weights)
   num_labels = tf.where(tf.equal(num_labels, 0.), 1., num_labels)
 
-  kl.get_shape().assert_has_rank(1)
-  weights.get_shape().assert_has_rank(1)
+  kl.get_shape().assert_has_rank(2)
+  weights.get_shape().assert_has_rank(2)
 
   loss = tf.identity(tf.reduce_sum(weights * kl) / num_labels, name='kl')
   return loss
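With logits now batch-major, the softmax branch reduces KL over the trailing class axis and the EOS mask is a rank-2 [batch_size, num_timesteps] matrix, matching the new assert_has_rank(2) checks. A rough, self-contained sketch of that weighted reduction (random inputs, purely illustrative):

import tensorflow as tf

q_logits = tf.random_normal([2, 4, 3])  # [batch, time, classes]
p_logits = tf.random_normal([2, 4, 3])
# 1.0 only at end-of-sequence positions, as the docstring requires.
weights = tf.constant([[0., 0., 0., 1.],
                       [0., 1., 0., 0.]])

q = tf.nn.softmax(q_logits)
# Per-position KL(q || p), reduced over the class axis (-1).
kl = tf.reduce_sum(
    q * (tf.nn.log_softmax(q_logits) - tf.nn.log_softmax(p_logits)), -1)

# Average over labeled positions, guarding against an all-zero mask.
num_labels = tf.reduce_sum(weights)
num_labels = tf.where(tf.equal(num_labels, 0.), 1., num_labels)
loss = tf.reduce_sum(weights * kl) / num_labels

with tf.Session() as sess:
  print(sess.run(loss))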
research/adversarial_text/data/data_utils.py

@@ -271,7 +271,7 @@ def build_labeled_sequence(seq, class_label, label_gain=False):
   Args:
     seq: SequenceWrapper.
-    class_label: bool.
+    class_label: integer, starting from 0.
     label_gain: bool. If True, class_label will be put on every timestep and
       weight will increase linearly from 0 to 1.
research/adversarial_text/data/document_generators.py

@@ -259,7 +259,7 @@ def dbpedia_documents(dataset='train',
         content=content,
         is_validation=is_validation,
         is_test=False,
-        label=int(row[0]),
+        label=int(row[0]) - 1,  # Labels should start from 0
         add_tokens=True)
research/adversarial_text/evaluate.py

@@ -25,7 +25,7 @@ import time
 import tensorflow as tf
 
-import graphs
+from adversarial_text import graphs
 
 flags = tf.app.flags
 FLAGS = flags.FLAGS
@@ -75,7 +75,8 @@ def run_eval(eval_ops, summary_writer, saver):
   Returns:
     dict<metric name, value>, with value being the average over all examples.
   """
-  sv = tf.train.Supervisor(logdir=FLAGS.eval_dir, saver=None, summary_op=None)
+  sv = tf.train.Supervisor(
+      logdir=FLAGS.eval_dir, saver=None, summary_op=None, summary_writer=None)
   with sv.managed_session(
       master=FLAGS.master, start_standard_services=False) as sess:
     if not restore_from_checkpoint(sess, saver):
@@ -113,6 +114,7 @@ def _log_values(sess, value_ops, summary_writer=None):
   if summary_writer is not None:
     global_step_val = sess.run(tf.train.get_global_step())
+    tf.logging.info('Finished eval for step ' + str(global_step_val))
     summary_writer.add_summary(summary, global_step_val)
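Passing summary_writer=None alongside summary_op=None keeps the Supervisor from starting its own summary service, so run_eval stays the only writer of event files for the eval directory. A hedged sketch of the resulting pattern (the logdir and graph here are placeholders, not the repo's code):

import tensorflow as tf

global_step = tf.train.get_or_create_global_step()

# No saver and no automatic summaries: the caller restores checkpoints
# and writes summaries itself.
sv = tf.train.Supervisor(
    logdir='/tmp/eval',  # placeholder path
    saver=None, summary_op=None, summary_writer=None)
with sv.managed_session('', start_standard_services=False) as sess:
  print('evaluating at step', sess.run(global_step))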
research/adversarial_text/graphs.py

@@ -24,9 +24,9 @@ import os
 import tensorflow as tf
 
-import adversarial_losses as adv_lib
-import inputs as inputs_lib
-import layers as layers_lib
+from adversarial_text import adversarial_losses as adv_lib
+from adversarial_text import inputs as inputs_lib
+from adversarial_text import layers as layers_lib
 
 flags = tf.app.flags
 FLAGS = flags.FLAGS
@@ -47,6 +47,8 @@ flags.DEFINE_integer('num_timesteps', 100, 'Number of timesteps for BPTT')
 # Model architechture
 flags.DEFINE_bool('bidir_lstm', False, 'Whether to build a bidirectional LSTM.')
+flags.DEFINE_bool('single_label', True, 'Whether the sequence has a single '
+                  'label, for optimization.')
 flags.DEFINE_integer('rnn_num_layers', 1, 'Number of LSTM layers.')
 flags.DEFINE_integer('rnn_cell_size', 512, 'Number of hidden units in the LSTM.')
@@ -181,7 +183,14 @@ class VatxtModel(object):
     self.tensors['cl_logits'] = logits
     self.tensors['cl_loss'] = loss
 
-    acc = layers_lib.accuracy(logits, inputs.labels, inputs.weights)
+    if FLAGS.single_label:
+      indices = tf.stack([tf.range(FLAGS.batch_size), inputs.length - 1], 1)
+      labels = tf.expand_dims(tf.gather_nd(inputs.labels, indices), 1)
+      weights = tf.expand_dims(tf.gather_nd(inputs.weights, indices), 1)
+    else:
+      labels = inputs.labels
+      weights = inputs.weights
+    acc = layers_lib.accuracy(logits, labels, weights)
     tf.summary.scalar('accuracy', acc)
 
     adv_loss = (self.adversarial_loss() * tf.constant(
@@ -189,11 +198,10 @@ class VatxtModel(object):
       tf.summary.scalar('adversarial_loss', adv_loss)
 
     total_loss = loss + adv_loss
-    tf.summary.scalar('total_classification_loss', total_loss)
 
     with tf.control_dependencies([inputs.save_state(next_state)]):
       total_loss = tf.identity(total_loss)
+    tf.summary.scalar('total_classification_loss', total_loss)
     return total_loss
 
   def language_model_graph(self, compute_loss=True):
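Moving the total_classification_loss summary below the control_dependencies block attaches it to the tf.identity result, so evaluating the summarized tensor also forces the state-save op to run. A toy sketch of that idiom (the variable and ops are invented stand-ins, not the model's):

import tensorflow as tf

state = tf.Variable(0.0)
total_loss = tf.constant(3.0)

save_state = tf.assign(state, 42.0)  # stand-in for inputs.save_state(...)
with tf.control_dependencies([save_state]):
  total_loss = tf.identity(total_loss)  # fetching this also runs save_state
tf.summary.scalar('total_classification_loss', total_loss)

with tf.Session() as sess:
  sess.run(tf.global_variables_initializer())
  print(sess.run(total_loss))  # 3.0, and save_state has now run
  print(sess.run(state))       # 42.0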
@@ -249,10 +257,17 @@ class VatxtModel(object):
     _, next_state, logits, _ = self.cl_loss_from_embedding(
         embedded, inputs=inputs, return_intermediates=True)
+    if FLAGS.single_label:
+      indices = tf.stack([tf.range(FLAGS.batch_size), inputs.length - 1], 1)
+      labels = tf.expand_dims(tf.gather_nd(inputs.labels, indices), 1)
+      weights = tf.expand_dims(tf.gather_nd(inputs.weights, indices), 1)
+    else:
+      labels = inputs.labels
+      weights = inputs.weights
     eval_ops = {
         'accuracy':
             tf.contrib.metrics.streaming_accuracy(
-                layers_lib.predictions(logits), inputs.labels, inputs.weights)
+                layers_lib.predictions(logits), labels, weights)
     }
 
     with tf.control_dependencies([inputs.save_state(next_state)]):
@@ -286,8 +301,16 @@ class VatxtModel(object):
     lstm_out, next_state = self.layers['lstm'](embedded, inputs.state,
                                                inputs.length)
+    if FLAGS.single_label:
+      indices = tf.stack([tf.range(FLAGS.batch_size), inputs.length - 1], 1)
+      lstm_out = tf.expand_dims(tf.gather_nd(lstm_out, indices), 1)
+      labels = tf.expand_dims(tf.gather_nd(inputs.labels, indices), 1)
+      weights = tf.expand_dims(tf.gather_nd(inputs.weights, indices), 1)
+    else:
+      labels = inputs.labels
+      weights = inputs.weights
     logits = self.layers['cl_logits'](lstm_out)
-    loss = layers_lib.classification_loss(logits, inputs.labels, inputs.weights)
+    loss = layers_lib.classification_loss(logits, labels, weights)
 
     if return_intermediates:
       return lstm_out, next_state, logits, loss
@@ -419,12 +442,12 @@ class VatxtBidirModel(VatxtModel):
       tf.summary.scalar('adversarial_loss', adv_loss)
 
     total_loss = loss + adv_loss
-    tf.summary.scalar('total_classification_loss', total_loss)
 
     saves = [inp.save_state(state) for (inp, state) in zip(inputs, next_states)]
     with tf.control_dependencies(saves):
      total_loss = tf.identity(total_loss)
+    tf.summary.scalar('total_classification_loss', total_loss)
 
     return total_loss
 
   def language_model_graph(self, compute_loss=True):
research/adversarial_text/graphs_test.py

@@ -29,7 +29,7 @@ import tempfile
 import tensorflow as tf
 
-import graphs
+from adversarial_text import graphs
 from adversarial_text.data import data_utils
 
 flags = tf.app.flags
research/adversarial_text/inputs.py

@@ -51,27 +51,16 @@ class VatxtInput(object):
         batch.sequences[data_utils.SequenceWrapper.F_TOKEN_ID])
     self._num_states = num_states
 
-    # Once the tokens have passed through embedding and LSTM, the output Tensor
-    # shapes will be time-major, i.e. shape = (time, batch, dim). Here we make
-    # both weights and labels time-major with a transpose, and then merge the
-    # time and batch dimensions such that they are both vectors of shape
-    # (time*batch).
     w = batch.sequences[data_utils.SequenceWrapper.F_WEIGHT]
-    w = tf.transpose(w, [1, 0])
-    w = tf.reshape(w, [-1])
     self._weights = w
 
     l = batch.sequences[data_utils.SequenceWrapper.F_LABEL]
-    l = tf.transpose(l, [1, 0])
-    l = tf.reshape(l, [-1])
     self._labels = l
 
     # eos weights
     self._eos_weights = None
     if eos_id:
       ew = tf.cast(tf.equal(self._tokens, eos_id), tf.float32)
-      ew = tf.transpose(ew, [1, 0])
-      ew = tf.reshape(ew, [-1])
       self._eos_weights = ew
 
   @property
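The deleted block flattened weights, labels, and EOS weights into time-major (time*batch,) vectors; after this commit they stay batch-major [batch, time] matrices. A small sketch contrasting the two layouts with toy values:

import tensorflow as tf

w = tf.constant([[1., 2., 3.],
                 [4., 5., 6.]])  # [batch=2, time=3]

# Old layout: transpose to time-major, then merge time and batch dims.
old = tf.reshape(tf.transpose(w, [1, 0]), [-1])
# New layout: left as a batch-major matrix.
new = w

with tf.Session() as sess:
  print(sess.run(old))  # [1. 4. 2. 5. 3. 6.]
  print(sess.run(new))  # [[1. 2. 3.] [4. 5. 6.]]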
research/adversarial_text/layers.py

@@ -16,12 +16,11 @@
 from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
-from six.moves import xrange  # pylint: disable=redefined-builtin
 
 # Dependency imports
 
 import tensorflow as tf
 
-K = tf.contrib.keras
+K = tf.keras
 
 def cl_logits_subgraph(layer_sizes, input_size, num_classes, keep_prob=1.):
@@ -96,7 +95,7 @@ class Embedding(K.layers.Layer):
 class LSTM(object):
-  """LSTM layer using static_rnn.
+  """LSTM layer using dynamic_rnn.
 
   Exposes variables in `trainable_weights` property.
   """
@@ -120,16 +119,11 @@ class LSTM(object):
       ])
 
     # shape(x) = (batch_size, num_timesteps, embedding_dim)
-    # Convert into a time-major list for static_rnn
-    x = tf.unstack(tf.transpose(x, perm=[1, 0, 2]))
 
-    lstm_out, next_state = tf.contrib.rnn.static_rnn(
+    lstm_out, next_state = tf.nn.dynamic_rnn(
         cell, x, initial_state=initial_state, sequence_length=seq_length)
 
-    # Merge time and batch dimensions
-    # shape(lstm_out) = timesteps * (batch_size, cell_size)
-    lstm_out = tf.concat(lstm_out, 0)
-    # shape(lstm_out) = (timesteps*batch_size, cell_size)
+    # shape(lstm_out) = (batch_size, timesteps, cell_size)
 
     if self.keep_prob < 1.:
       lstm_out = tf.nn.dropout(lstm_out, self.keep_prob)
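tf.nn.dynamic_rnn consumes the batch-major tensor directly and returns batch-major output, which is what lets the unstack/transpose/concat bookkeeping above disappear. A self-contained sketch of the call, with arbitrary sizes (not the repo's configuration):

import tensorflow as tf

batch_size, num_timesteps, dim, cell_size = 2, 5, 8, 16
x = tf.random_normal([batch_size, num_timesteps, dim])
seq_length = tf.constant([5, 3])

cell = tf.contrib.rnn.BasicLSTMCell(cell_size)
initial_state = cell.zero_state(batch_size, tf.float32)

# Input and output are both (batch, time, ...) tensors; no unstacking.
lstm_out, next_state = tf.nn.dynamic_rnn(
    cell, x, initial_state=initial_state, sequence_length=seq_length)

with tf.Session() as sess:
  sess.run(tf.global_variables_initializer())
  print(sess.run(lstm_out).shape)  # (2, 5, 16)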
@@ -154,6 +148,7 @@ class SoftmaxLoss(K.layers.Layer):
     self.num_candidate_samples = num_candidate_samples
     self.vocab_freqs = vocab_freqs
     super(SoftmaxLoss, self).__init__(**kwargs)
+    self.multiclass_dense_layer = K.layers.Dense(self.vocab_size)
 
   def build(self, input_shape):
     input_shape = input_shape[0]

@@ -166,6 +161,7 @@ class SoftmaxLoss(K.layers.Layer):
         shape=(self.vocab_size,),
         name='lm_lin_b',
         initializer=K.initializers.glorot_uniform())
+    self.multiclass_dense_layer.build(input_shape)
 
     super(SoftmaxLoss, self).build(input_shape)
@@ -173,25 +169,30 @@ class SoftmaxLoss(K.layers.Layer):
     x, labels, weights = inputs
     if self.num_candidate_samples > -1:
       assert self.vocab_freqs is not None
-      labels = tf.expand_dims(labels, -1)
+      labels_reshaped = tf.reshape(labels, [-1])
+      labels_reshaped = tf.expand_dims(labels_reshaped, -1)
       sampled = tf.nn.fixed_unigram_candidate_sampler(
-          true_classes=labels,
+          true_classes=labels_reshaped,
           num_true=1,
           num_sampled=self.num_candidate_samples,
           unique=True,
           range_max=self.vocab_size,
           unigrams=self.vocab_freqs)
+      inputs_reshaped = tf.reshape(x, [-1, int(x.get_shape()[2])])
 
       lm_loss = tf.nn.sampled_softmax_loss(
           weights=tf.transpose(self.lin_w),
           biases=self.lin_b,
-          labels=labels,
-          inputs=x,
+          labels=labels_reshaped,
+          inputs=inputs_reshaped,
           num_sampled=self.num_candidate_samples,
           num_classes=self.vocab_size,
           sampled_values=sampled)
+      lm_loss = tf.reshape(
+          lm_loss, [int(x.get_shape()[0]), int(x.get_shape()[1])])
     else:
-      logits = tf.matmul(x, self.lin_w) + self.lin_b
+      logits = self.multiclass_dense_layer(x)
       lm_loss = tf.nn.sparse_softmax_cross_entropy_with_logits(
           logits=logits, labels=labels)
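tf.nn.sampled_softmax_loss expects 2-D activations and a [num_examples, num_true] label matrix, which is why the hunk above flattens the 3-D inputs before sampling and reshapes the per-position losses back afterwards. A rough standalone sketch of that round trip (the vocabulary, unigram frequencies, and projection weights are invented):

import tensorflow as tf

batch, time, dim, vocab = 2, 3, 4, 10
x = tf.random_normal([batch, time, dim])
labels = tf.constant([[1, 2, 3], [4, 5, 6]], dtype=tf.int64)
vocab_freqs = [1.0] * vocab  # made-up unigram counts

lin_w = tf.Variable(tf.random_normal([dim, vocab]))
lin_b = tf.Variable(tf.zeros([vocab]))

# Flatten to 2-D for the sampler and the sampled loss.
labels_reshaped = tf.expand_dims(tf.reshape(labels, [-1]), -1)
inputs_reshaped = tf.reshape(x, [-1, dim])

sampled = tf.nn.fixed_unigram_candidate_sampler(
    true_classes=labels_reshaped, num_true=1, num_sampled=5,
    unique=True, range_max=vocab, unigrams=vocab_freqs)

lm_loss = tf.nn.sampled_softmax_loss(
    weights=tf.transpose(lin_w), biases=lin_b,
    labels=labels_reshaped, inputs=inputs_reshaped,
    num_sampled=5, num_classes=vocab, sampled_values=sampled)
# Restore the [batch, time] layout of the per-position losses.
lm_loss = tf.reshape(lm_loss, [batch, time])

with tf.Session() as sess:
  sess.run(tf.global_variables_initializer())
  print(sess.run(lm_loss).shape)  # (2, 3)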
@@ -218,7 +219,7 @@ def classification_loss(logits, labels, weights):
   # Logistic loss
   if inner_dim == 1:
     loss = tf.nn.sigmoid_cross_entropy_with_logits(
-        logits=tf.squeeze(logits), labels=tf.cast(labels, tf.float32))
+        logits=tf.squeeze(logits, -1), labels=tf.cast(labels, tf.float32))
   # Softmax loss
   else:
     loss = tf.nn.sparse_softmax_cross_entropy_with_logits(
@@ -253,10 +254,10 @@ def predictions(logits):
   with tf.name_scope('predictions'):
     # For binary classification
     if inner_dim == 1:
-      pred = tf.cast(tf.greater(tf.squeeze(logits), 0.5), tf.int64)
+      pred = tf.cast(tf.greater(tf.squeeze(logits, -1), 0.5), tf.int64)
     # For multi-class classification
     else:
-      pred = tf.argmax(logits, 1)
+      pred = tf.argmax(logits, 2)
   return pred
@@ -355,10 +356,9 @@ def optimize(loss,
         opt.ready_for_local_init_op)
   else:
     # Non-sync optimizer
-    variables_averages_op = variable_averages.apply(tvars)
     apply_gradient_op = opt.apply_gradients(grads_and_vars, global_step)
-    with tf.control_dependencies([apply_gradient_op, variables_averages_op]):
-      train_op = tf.no_op(name='train_op')
+    with tf.control_dependencies([apply_gradient_op]):
+      train_op = variable_averages.apply(tvars)
 
   return train_op
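The rewritten non-sync branch sequences the two updates instead of grouping them: gradients are applied first, and the returned train op is the moving-average update itself, so the averages always see post-step weights. A toy sketch of that ordering (the optimizer, decay, and variable are placeholders):

import tensorflow as tf

var = tf.Variable(1.0)
loss = tf.square(var)
global_step = tf.train.get_or_create_global_step()

opt = tf.train.GradientDescentOptimizer(0.1)
grads_and_vars = opt.compute_gradients(loss)
variable_averages = tf.train.ExponentialMovingAverage(0.999, global_step)

apply_gradient_op = opt.apply_gradients(grads_and_vars, global_step)
# The EMA update runs strictly after the gradient step.
with tf.control_dependencies([apply_gradient_op]):
  train_op = variable_averages.apply([var])

with tf.Session() as sess:
  sess.run(tf.global_variables_initializer())
  sess.run(train_op)
  print(sess.run(variable_averages.average(var)))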
research/adversarial_text/pretrain.py

@@ -27,8 +27,8 @@ from __future__ import print_function
 import tensorflow as tf
 
-import graphs
-import train_utils
+from adversarial_text import graphs
+from adversarial_text import train_utils
 
 FLAGS = tf.app.flags.FLAGS
research/adversarial_text/train_classifier.py

@@ -35,8 +35,8 @@ from __future__ import print_function
 import tensorflow as tf
 
-import graphs
-import train_utils
+from adversarial_text import graphs
+from adversarial_text import train_utils
 
 flags = tf.app.flags
 FLAGS = flags.FLAGS
research/adversarial_text/train_utils.py

@@ -64,8 +64,8 @@ def run_training(train_op,
   sv = tf.train.Supervisor(
       logdir=FLAGS.train_dir,
       is_chief=is_chief,
-      save_summaries_secs=5 * 60,
-      save_model_secs=5 * 60,
+      save_summaries_secs=30,
+      save_model_secs=30,
       local_init_op=local_init_op,
       ready_for_local_init_op=ready_for_local_init_op,
       global_step=global_step)
@@ -90,10 +90,9 @@ def run_training(train_op,
     global_step_val = 0
     while not sv.should_stop() and global_step_val < FLAGS.max_steps:
       global_step_val = train_step(sess, train_op, loss, global_step)
 
-    sv.stop()
-
     # Final checkpoint
-    if is_chief:
+    if is_chief and global_step_val >= FLAGS.max_steps:
       sv.saver.save(sess, sv.save_path, global_step=global_step)