ModelZoo / ResNet50_tensorflow · Commits · e8dd2bf3

Commit e8dd2bf3, authored Oct 25, 2017 by Andrew M. Dai

Fix adversarial training with recent shape changes.

PiperOrigin-RevId: 173414999
parent 7d921c12

Showing 3 changed files with 25 additions and 17 deletions (+25 -17):

research/adversarial_text/adversarial_losses.py  +16 -8
research/adversarial_text/graphs.py              +7 -7
research/adversarial_text/layers.py              +2 -2
research/adversarial_text/adversarial_losses.py
@@ -73,7 +73,7 @@ def virtual_adversarial_loss(logits, embedded, inputs,
     between the new logits and the original logits.

   Args:
-    logits: 2-D float Tensor, [num_timesteps*batch_size, m], where m=1 if
+    logits: 3-D float Tensor, [batch_size, num_timesteps, m], where m=1 if
       num_classes=2, otherwise m=num_classes.
     embedded: 3-D float Tensor, [batch_size, num_timesteps, embedding_dim].
     inputs: VatxtInput.
@@ -89,6 +89,9 @@ def virtual_adversarial_loss(logits, embedded, inputs,
   # Only care about the KL divergence on the final timestep.
   weights = inputs.eos_weights
   assert weights is not None
+  if FLAGS.single_label:
+    indices = tf.stack([tf.range(FLAGS.batch_size), inputs.length - 1], 1)
+    weights = tf.expand_dims(tf.gather_nd(inputs.eos_weights, indices), 1)

   # Initialize perturbation with random noise.
   # shape(embedded) = (batch_size, num_timesteps, embedding_dim)
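
Note on the added FLAGS.single_label branch: it pairs each batch index with its final timestep (length - 1), gathers the end-of-sequence weight for that single position, and restores a length-1 time axis. A minimal sketch of that indexing, assuming TF 1.x graph mode and toy values for eos_weights and length (all values here are illustrative only):

import tensorflow as tf

batch_size = 3
# eos_weights is 1.0 only at each example's end-of-sequence position.
eos_weights = tf.constant([[0., 0., 0., 1., 0.],
                           [0., 1., 0., 0., 0.],
                           [0., 0., 0., 0., 1.]])
length = tf.constant([4, 2, 5])  # hypothetical sequence lengths

# One (row, column) coordinate per example: [batch_index, length - 1].
indices = tf.stack([tf.range(batch_size), length - 1], 1)
# gather_nd picks one scalar per example -> shape [batch_size];
# expand_dims restores a singleton time axis -> shape [batch_size, 1].
weights = tf.expand_dims(tf.gather_nd(eos_weights, indices), 1)

with tf.Session() as sess:
  print(sess.run(weights))  # [[1.], [1.], [1.]]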
@@ -101,6 +104,7 @@ def virtual_adversarial_loss(logits, embedded, inputs,
   for _ in xrange(FLAGS.num_power_iteration):
     d = _scale_l2(
         _mask_by_length(d, inputs.length), FLAGS.small_constant_for_finite_diff)
     d_logits = logits_from_embedding_fn(embedded + d)
     kl = _kl_divergence_with_logits(logits, d_logits, weights)
     d, = tf.gradients(
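
For context, this loop is the finite-difference power iteration used by virtual adversarial training: scale a masked random perturbation d to a small L2 norm, re-run the classifier on embedded + d, measure the KL divergence against the clean predictions, and use the gradient of that KL with respect to d as the next perturbation direction. Below is a rough, self-contained sketch of one such step, with a toy linear classifier standing in for logits_from_embedding_fn and a simplified L2 scaling in place of the repo's _scale_l2/_mask_by_length helpers (TF 1.x assumed; nothing here is the repo's actual implementation):

import tensorflow as tf

batch_size, num_timesteps, emb_dim = 2, 4, 3
embedded = tf.random_normal([batch_size, num_timesteps, emb_dim])
proj = tf.random_normal([emb_dim, 1])  # toy classifier: one logit per timestep

def logits_fn(emb):
  return tf.tensordot(emb, proj, axes=1)  # [batch, time, 1]

def scale_l2(x, norm_length):
  # Rescale each example's perturbation to a fixed L2 norm (simplified).
  l2 = tf.sqrt(tf.reduce_sum(tf.square(x), axis=[1, 2], keep_dims=True) + 1e-12)
  return norm_length * x / l2

logits = tf.stop_gradient(logits_fn(embedded))  # clean predictions, held fixed
q = tf.nn.sigmoid(logits)

d = tf.random_normal(tf.shape(embedded))        # initial random perturbation
d = scale_l2(d, 1e-1)
d_logits = logits_fn(embedded + d)
kl = tf.reduce_mean(
    -tf.nn.sigmoid_cross_entropy_with_logits(logits=logits, labels=q) +
    tf.nn.sigmoid_cross_entropy_with_logits(logits=d_logits, labels=q))
# The gradient of the KL w.r.t. d gives the next power-iteration direction.
d_next, = tf.gradients(kl, d)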
@@ -141,6 +145,9 @@ def virtual_adversarial_loss_bidir(logits, embedded, inputs,
   logits = tf.stop_gradient(logits)
   f_inputs, _ = inputs
   weights = f_inputs.eos_weights
+  if FLAGS.single_label:
+    indices = tf.stack([tf.range(FLAGS.batch_size), f_inputs.length - 1], 1)
+    weights = tf.expand_dims(tf.gather_nd(f_inputs.eos_weights, indices), 1)
   assert weights is not None

   perturbs = [
@@ -194,10 +201,10 @@ def _kl_divergence_with_logits(q_logits, p_logits, weights):
   Args:
     q_logits: logits for 1st argument of KL divergence shape
-              [num_timesteps * batch_size, num_classes] if num_classes > 2, and
-              [num_timesteps * batch_size] if num_classes == 2.
+              [batch_size, num_timesteps, num_classes] if num_classes > 2, and
+              [batch_size, num_timesteps] if num_classes == 2.
     p_logits: logits for 2nd argument of KL divergence with same shape q_logits.
-    weights: 1-D float tensor with shape [num_timesteps * batch_size].
+    weights: 1-D float tensor with shape [batch_size, num_timesteps].
              Elements should be 1.0 only on end of sequences

   Returns:
@@ -208,18 +215,19 @@ def _kl_divergence_with_logits(q_logits, p_logits, weights):
     q = tf.nn.sigmoid(q_logits)
     kl = (-tf.nn.sigmoid_cross_entropy_with_logits(logits=q_logits, labels=q) +
           tf.nn.sigmoid_cross_entropy_with_logits(logits=p_logits, labels=q))
-    kl = tf.squeeze(kl)
+    kl = tf.squeeze(kl, 2)

   # For softmax regression
   else:
     q = tf.nn.softmax(q_logits)
-    kl = tf.reduce_sum(
-        q * (tf.nn.log_softmax(q_logits) - tf.nn.log_softmax(p_logits)), 1)
+    kl = tf.reduce_sum(
+        q * (tf.nn.log_softmax(q_logits) - tf.nn.log_softmax(p_logits)), -1)

   num_labels = tf.reduce_sum(weights)
   num_labels = tf.where(tf.equal(num_labels, 0.), 1., num_labels)

-  kl.get_shape().assert_has_rank(1)
-  weights.get_shape().assert_has_rank(1)
+  kl.get_shape().assert_has_rank(2)
+  weights.get_shape().assert_has_rank(2)

   loss = tf.identity(tf.reduce_sum(weights * kl) / num_labels, name='kl')
   return loss
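
A hedged illustration of why the rank assertions change from 1 to 2: with batch-major logits of shape [batch_size, num_timesteps, num_classes], reducing the KL term over the class axis leaves a [batch_size, num_timesteps] tensor, which is then weighted by an eos-weight tensor of the same shape. Toy shapes and values, TF 1.x assumed:

import tensorflow as tf

batch_size, num_timesteps, num_classes = 2, 3, 4
q_logits = tf.random_normal([batch_size, num_timesteps, num_classes])
p_logits = tf.random_normal([batch_size, num_timesteps, num_classes])
weights = tf.constant([[0., 0., 1.],
                       [0., 1., 0.]])  # 1.0 only at each end of sequence

q = tf.nn.softmax(q_logits)
# Reducing over the last (class) axis leaves [batch_size, num_timesteps],
# hence the new assert_has_rank(2) on both kl and weights.
kl = tf.reduce_sum(
    q * (tf.nn.log_softmax(q_logits) - tf.nn.log_softmax(p_logits)), -1)

kl.get_shape().assert_has_rank(2)
weights.get_shape().assert_has_rank(2)

loss = tf.reduce_sum(weights * kl) / tf.reduce_sum(weights)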
research/adversarial_text/graphs.py
@@ -185,8 +185,8 @@ class VatxtModel(object):
     if FLAGS.single_label:
       indices = tf.stack([tf.range(FLAGS.batch_size), inputs.length - 1], 1)
-      labels = tf.gather_nd(inputs.labels, indices)
-      weights = tf.gather_nd(inputs.weights, indices)
+      labels = tf.expand_dims(tf.gather_nd(inputs.labels, indices), 1)
+      weights = tf.expand_dims(tf.gather_nd(inputs.weights, indices), 1)
     else:
       labels = inputs.labels
       weights = inputs.weights
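
The recurring pattern in this file: tf.gather_nd collapses the time axis to a single gathered position per example, and the added tf.expand_dims(..., 1) restores a length-1 time axis so downstream code that expects [batch_size, num_timesteps] labels and weights keeps working. A small shape check with toy tensors (TF 1.x assumed, values illustrative):

import tensorflow as tf

batch_size, num_timesteps = 3, 5
labels = tf.zeros([batch_size, num_timesteps], dtype=tf.int64)
length = tf.constant([5, 2, 4])  # hypothetical sequence lengths

indices = tf.stack([tf.range(batch_size), length - 1], 1)

gathered = tf.gather_nd(labels, indices)    # shape [batch_size]
single = tf.expand_dims(gathered, 1)        # shape [batch_size, 1]

print(gathered.get_shape())  # (3,)
print(single.get_shape())    # (3, 1)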
@@ -259,8 +259,8 @@ class VatxtModel(object):
     if FLAGS.single_label:
       indices = tf.stack([tf.range(FLAGS.batch_size), inputs.length - 1], 1)
-      labels = tf.gather_nd(inputs.labels, indices)
-      weights = tf.gather_nd(inputs.weights, indices)
+      labels = tf.expand_dims(tf.gather_nd(inputs.labels, indices), 1)
+      weights = tf.expand_dims(tf.gather_nd(inputs.weights, indices), 1)
     else:
       labels = inputs.labels
       weights = inputs.weights
@@ -303,9 +303,9 @@ class VatxtModel(object):
         inputs.length)
     if FLAGS.single_label:
       indices = tf.stack([tf.range(FLAGS.batch_size), inputs.length - 1], 1)
-      lstm_out = tf.gather_nd(lstm_out, indices)
-      labels = tf.gather_nd(inputs.labels, indices)
-      weights = tf.gather_nd(inputs.weights, indices)
+      lstm_out = tf.expand_dims(tf.gather_nd(lstm_out, indices), 1)
+      labels = tf.expand_dims(tf.gather_nd(inputs.labels, indices), 1)
+      weights = tf.expand_dims(tf.gather_nd(inputs.weights, indices), 1)
     else:
       labels = inputs.labels
       weights = inputs.weights
research/adversarial_text/layers.py
@@ -217,7 +217,7 @@ def classification_loss(logits, labels, weights):
   # Logistic loss
   if inner_dim == 1:
     loss = tf.nn.sigmoid_cross_entropy_with_logits(
-        logits=tf.squeeze(logits), labels=tf.cast(labels, tf.float32))
+        logits=tf.squeeze(logits, -1), labels=tf.cast(labels, tf.float32))
   # Softmax loss
   else:
     loss = tf.nn.sparse_softmax_cross_entropy_with_logits(
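
Making the squeeze axis explicit matters once logits carry a trailing singleton class dimension: a bare tf.squeeze drops every size-1 dimension, including a batch or time dimension that happens to be 1, while tf.squeeze(logits, -1) removes only the last one. A minimal illustration with hypothetical shapes:

import tensorflow as tf

logits = tf.zeros([1, 4, 1])  # 1 example, 4 timesteps, 1 logit per position

print(tf.squeeze(logits).get_shape())      # (4,)   -- batch axis dropped too
print(tf.squeeze(logits, -1).get_shape())  # (1, 4) -- only the class axis removed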
@@ -252,7 +252,7 @@ def predictions(logits):
   with tf.name_scope('predictions'):
     # For binary classification
     if inner_dim == 1:
-      pred = tf.cast(tf.greater(tf.squeeze(logits), 0.5), tf.int64)
+      pred = tf.cast(tf.greater(tf.squeeze(logits, -1), 0.5), tf.int64)
     # For multi-class classification
     else:
       pred = tf.argmax(logits, 1)