ModelZoo / ResNet50_tensorflow · Commits

Commit a97304d5
Authored May 03, 2017 by Ryan Sepassi

Updates to adversarial_text model

Parent: 4cc1fa0f
Showing 3 changed files with 35 additions and 20 deletions (+35 -20)
adversarial_text/adversarial_losses.py   +14 -5
adversarial_text/evaluate.py             +16 -9
adversarial_text/layers.py                +5 -6
adversarial_text/adversarial_losses.py
@@ -83,8 +83,10 @@ def virtual_adversarial_loss(logits, embedded, inputs,
   """
+  # Stop gradient of logits. See https://arxiv.org/abs/1507.00677 for details.
+  logits = tf.stop_gradient(logits)
   # Only care about the KL divergence on the final timestep.
   weights = _end_of_seq_mask(inputs.labels)
   # Initialize perturbation with random noise.
   # shape(embedded) = (batch_size, num_timesteps, embedding_dim)
   d = _mask_by_length(tf.random_normal(shape=tf.shape(embedded)), inputs.length)
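The added tf.stop_gradient call treats the clean logits as a constant while the virtual adversarial perturbation is optimized, so no gradient flows back into the network through that branch. A tiny TF 1.x sketch of the effect (illustrative only, not part of this commit):

# Illustrative TF 1.x sketch, not part of this commit: tf.stop_gradient blocks
# backpropagation through the stopped tensor.
import tensorflow as tf

x = tf.constant(3.0)
y = x * x
y_stopped = tf.stop_gradient(y)

grad_normal, = tf.gradients(y, [x])          # d(x*x)/dx = 2 * x
grad_stopped = tf.gradients(y_stopped, [x])  # [None]: no gradient flows

with tf.Session() as sess:
  print(sess.run(grad_normal))  # 6.0
  print(grad_stopped)           # [None]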
@@ -173,11 +175,15 @@ def _mask_by_length(t, length):
 def _scale_l2(x, norm_length):
   # shape(x) = (batch, num_timesteps, d)
-  x /= (1e-12 + tf.reduce_max(tf.abs(x), 2, keep_dims=True))
-  x_2 = tf.reduce_sum(tf.pow(x, 2), 2, keep_dims=True)
-  x /= tf.sqrt(1e-6 + x_2)
-  return norm_length * x
+  # Divide x by max(abs(x)) for a numerically stable L2 norm.
+  # 2norm(x) = a * 2norm(x/a)
+  # Scale over the full sequence, dims (1, 2)
+  alpha = tf.reduce_max(tf.abs(x), (1, 2), keep_dims=True) + 1e-12
+  l2_norm = alpha * tf.sqrt(
+      tf.reduce_sum(tf.pow(x / alpha, 2), (1, 2), keep_dims=True) + 1e-6)
+  x_unit = x / l2_norm
+  return norm_length * x_unit


 def _end_of_seq_mask(tokens):
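The new _scale_l2 first factors out alpha = max|x| so the sum of squares is computed on values in [-1, 1], then multiplies alpha back into the norm; the output still has L2 norm close to norm_length over dims (1, 2). A rough NumPy check of that identity (not part of the commit; shapes are illustrative):

# Rough NumPy check, not part of this commit: factoring out alpha preserves
# the L2 norm while keeping the squared values in a safe numeric range.
import numpy as np

def scale_l2_np(x, norm_length):
  # x: (batch, num_timesteps, d)
  alpha = np.abs(x).max(axis=(1, 2), keepdims=True) + 1e-12
  l2_norm = alpha * np.sqrt(
      np.sum((x / alpha) ** 2, axis=(1, 2), keepdims=True) + 1e-6)
  return norm_length * x / l2_norm

x = np.random.randn(4, 7, 32).astype(np.float32)
scaled = scale_l2_np(x, norm_length=5.0)
# Each example now has L2 norm close to norm_length over dims (1, 2).
print(np.sqrt((scaled ** 2).sum(axis=(1, 2))))  # ~[5. 5. 5. 5.]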
@@ -225,5 +231,8 @@ def _kl_divergence_with_logits(q_logits, p_logits, weights):
   num_labels = tf.reduce_sum(weights)
   num_labels = tf.where(tf.equal(num_labels, 0.), 1., num_labels)
-  loss = tf.identity(tf.reduce_sum(weights * kl) / num_labels, name='kl')
+  kl.get_shape().assert_has_rank(2)
+  weights.get_shape().assert_has_rank(1)
+  loss = tf.identity(tf.reduce_sum(
+      tf.expand_dims(weights, -1) * kl) / num_labels, name='kl')
   return loss
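With the added rank asserts, kl is rank 2 and weights rank 1, so tf.expand_dims(weights, -1) broadcasts the per-position weights across the last axis of kl before taking the masked mean. A minimal NumPy sketch of that computation (illustrative shapes and values, not from the commit):

# Minimal NumPy sketch, not from this commit: rank-1 weights broadcast over
# rank-2 kl, then the sum is divided by the number of unmasked positions.
import numpy as np

kl = np.array([[0.2], [0.5], [0.1]], dtype=np.float32)  # rank 2
weights = np.array([1., 0., 1.], dtype=np.float32)      # rank 1

num_labels = max(weights.sum(), 1.0)                 # guard against all-zero weights
loss = (weights[:, None] * kl).sum() / num_labels    # expand_dims(weights, -1) * kl
print(loss)  # (0.2 + 0.1) / 2 = 0.15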
adversarial_text/evaluate.py
@@ -84,28 +84,35 @@ def run_eval(eval_ops, summary_writer, saver):
     metric_names, ops = zip(*eval_ops.items())
     value_ops, update_ops = zip(*ops)
+    value_ops_dict = dict(zip(metric_names, value_ops))

     # Run update ops
     num_batches = int(math.ceil(FLAGS.num_examples / FLAGS.batch_size))
     tf.logging.info('Running %d batches for evaluation.', num_batches)
     for i in range(num_batches):
       if (i + 1) % 10 == 0:
         tf.logging.info('Running batch %d/%d...', i + 1, num_batches)
+      if (i + 1) % 50 == 0:
+        _log_values(sess, value_ops_dict)
       sess.run(update_ops)

+    _log_values(sess, value_ops_dict, summary_writer=summary_writer)
+
+
+def _log_values(sess, value_ops, summary_writer=None):
+  metric_names, value_ops = zip(*value_ops.items())
   values = sess.run(value_ops)
   metric_values = dict(zip(metric_names, values))

   tf.logging.info('Eval metric values:')
   summary = tf.summary.Summary()
-  for name, val in metric_values.items():
+  for name, val in zip(metric_names, values):
     summary.value.add(tag=name, simple_value=val)
     tf.logging.info('%s = %.3f', name, val)

+  if summary_writer is not None:
     global_step_val = sess.run(tf.train.get_global_step())
     summary_writer.add_summary(summary, global_step_val)

   return metric_values


 def main(_):
   tf.logging.set_verbosity(tf.logging.INFO)
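run_eval and the new _log_values helper treat eval_ops as a dict mapping a metric name to a (value_op, update_op) pair, which is what the TF 1.x tf.metrics API returns. A hypothetical sketch of wiring such a dict (names and values are illustrative, not from this repository):

# Hypothetical TF 1.x sketch, not from this commit: the eval_ops structure
# that run_eval/_log_values expect, built from the tf.metrics API.
import tensorflow as tf

labels = tf.constant([1, 0, 1, 1])
predictions = tf.constant([1, 0, 0, 1])

eval_ops = {
    'accuracy': tf.metrics.accuracy(labels, predictions),  # (value_op, update_op)
}

metric_names, ops = zip(*eval_ops.items())
value_ops, update_ops = zip(*ops)
value_ops_dict = dict(zip(metric_names, value_ops))

with tf.Session() as sess:
  sess.run(tf.local_variables_initializer())  # metric counters are local variables
  sess.run(update_ops)
  print(sess.run(value_ops_dict))  # {'accuracy': 0.75}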
adversarial_text/layers.py
@@ -81,11 +81,10 @@ class Embedding(K.layers.Layer):
   def _normalize(self, emb):
     weights = self.vocab_freqs / tf.reduce_sum(self.vocab_freqs)
-    emb -= tf.reduce_sum(weights * emb, 0, keep_dims=True)
-    emb /= tf.sqrt(1e-6 + tf.reduce_sum(weights * tf.pow(emb, 2.), 0, keep_dims=True))
-    return emb
+    mean = tf.reduce_sum(weights * emb, 0, keep_dims=True)
+    var = tf.reduce_sum(weights * tf.pow(emb - mean, 2.), 0, keep_dims=True)
+    stddev = tf.sqrt(1e-6 + var)
+    return (emb - mean) / stddev


 class LSTM(object):
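The rewritten _normalize computes the frequency-weighted mean and variance explicitly instead of mutating emb in place; the result is the same weighted standardization as before. A rough NumPy check (not part of the commit; sizes are illustrative):

# Rough NumPy check, not part of this commit: the explicit mean/var/stddev
# form matches the old subtract-then-divide formulation.
import numpy as np

vocab_freqs = np.array([10., 5., 1.], dtype=np.float32)[:, None]  # (vocab, 1)
emb = np.random.randn(3, 4).astype(np.float32)                    # (vocab, dim)
weights = vocab_freqs / vocab_freqs.sum()

mean = (weights * emb).sum(axis=0, keepdims=True)
var = (weights * (emb - mean) ** 2).sum(axis=0, keepdims=True)
normalized = (emb - mean) / np.sqrt(1e-6 + var)

# Old formulation: subtract the weighted mean, then divide by the weighted RMS
# of the centered embeddings; the result is identical.
old = emb - (weights * emb).sum(axis=0, keepdims=True)
old /= np.sqrt(1e-6 + (weights * old ** 2).sum(axis=0, keepdims=True))
print(np.allclose(normalized, old))  # True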
@@ -201,7 +200,7 @@ def classification_loss(logits, labels, weights):
     logits: 2-D [timesteps*batch_size, m] float tensor, where m=1 if
       num_classes=2, otherwise m=num_classes.
     labels: 1-D [timesteps*batch_size] integer tensor.
-    weights: 2-D [timesteps*batch_size] float tensor.
+    weights: 1-D [timesteps*batch_size] float tensor.

   Returns:
     Loss scalar of type float.