Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
ModelZoo
ResNet50_tensorflow
Commits
9f0a567d
Commit
9f0a567d
authored
Feb 14, 2018
by
Andrew M. Dai
Browse files
Fixes to multiclass training. Change DBpedia data generation to generate labels starting from 0.
PiperOrigin-RevId: 185783053
parent
9e30188e
Changes
7
Hide whitespace changes
Inline
Side-by-side
Showing
7 changed files
with
19 additions
and
15 deletions
+19
-15
research/adversarial_text/BUILD
research/adversarial_text/BUILD
+4
-3
research/adversarial_text/README.md
research/adversarial_text/README.md
+1
-0
research/adversarial_text/adversarial_losses.py
research/adversarial_text/adversarial_losses.py
+2
-0
research/adversarial_text/data/data_utils.py
research/adversarial_text/data/data_utils.py
+1
-1
research/adversarial_text/data/document_generators.py
research/adversarial_text/data/document_generators.py
+1
-1
research/adversarial_text/layers.py
research/adversarial_text/layers.py
+7
-6
research/adversarial_text/train_utils.py
research/adversarial_text/train_utils.py
+3
-4
No files found.
research/adversarial_text/BUILD
View file @
9f0a567d
...
...
@@ -10,7 +10,7 @@ py_binary(
deps
=
[
":graphs"
,
# google3 file dep,
# tensorflow dep,
# tensorflow
internal
dep,
],
)
...
...
@@ -21,7 +21,7 @@ py_binary(
":graphs"
,
":train_utils"
,
# google3 file dep,
# tensorflow dep,
# tensorflow
internal
dep,
],
)
...
...
@@ -34,7 +34,8 @@ py_binary(
":graphs"
,
":train_utils"
,
# google3 file dep,
# tensorflow dep,
# tensorflow internal gpu deps
# tensorflow internal dep,
],
)
...
...
research/adversarial_text/README.md
View file @
9f0a567d
...
...
@@ -154,3 +154,4 @@ control which dataset is processed and how.
## Contact for Issues
*
Ryan Sepassi, @rsepassi
*
Andrew M. Dai, @a-dai
research/adversarial_text/adversarial_losses.py
View file @
9f0a567d
...
...
@@ -38,6 +38,8 @@ flags.DEFINE_float('small_constant_for_finite_diff', 1e-1,
# Parameters for building the graph
flags
.
DEFINE_string
(
'adv_training_method'
,
None
,
'The flag which specifies training method. '
'"" : non-adversarial training (e.g. for running the '
' semi-supervised sequence learning model) '
'"rp" : random perturbation training '
'"at" : adversarial training '
'"vat" : virtual adversarial training '
...
...
research/adversarial_text/data/data_utils.py
View file @
9f0a567d
...
...
@@ -271,7 +271,7 @@ def build_labeled_sequence(seq, class_label, label_gain=False):
Args:
seq: SequenceWrapper.
class_label:
bool
.
class_label:
integer, starting from 0
.
label_gain: bool. If True, class_label will be put on every timestep and
weight will increase linearly from 0 to 1.
...
...
research/adversarial_text/data/document_generators.py
View file @
9f0a567d
...
...
@@ -259,7 +259,7 @@ def dbpedia_documents(dataset='train',
content
=
content
,
is_validation
=
is_validation
,
is_test
=
False
,
label
=
int
(
row
[
0
])
,
label
=
int
(
row
[
0
])
-
1
,
# Labels should start from 0
add_tokens
=
True
)
...
...
research/adversarial_text/layers.py
View file @
9f0a567d
...
...
@@ -20,7 +20,7 @@ from __future__ import print_function
# Dependency imports
import
tensorflow
as
tf
K
=
tf
.
contrib
.
keras
K
=
tf
.
keras
def
cl_logits_subgraph
(
layer_sizes
,
input_size
,
num_classes
,
keep_prob
=
1.
):
...
...
@@ -148,6 +148,7 @@ class SoftmaxLoss(K.layers.Layer):
self
.
num_candidate_samples
=
num_candidate_samples
self
.
vocab_freqs
=
vocab_freqs
super
(
SoftmaxLoss
,
self
).
__init__
(
**
kwargs
)
self
.
multiclass_dense_layer
=
K
.
layers
.
Dense
(
self
.
vocab_size
)
def
build
(
self
,
input_shape
):
input_shape
=
input_shape
[
0
]
...
...
@@ -160,6 +161,7 @@ class SoftmaxLoss(K.layers.Layer):
shape
=
(
self
.
vocab_size
,),
name
=
'lm_lin_b'
,
initializer
=
K
.
initializers
.
glorot_uniform
())
self
.
multiclass_dense_layer
.
build
(
input_shape
)
super
(
SoftmaxLoss
,
self
).
build
(
input_shape
)
...
...
@@ -190,7 +192,7 @@ class SoftmaxLoss(K.layers.Layer):
lm_loss
,
[
int
(
x
.
get_shape
()[
0
]),
int
(
x
.
get_shape
()[
1
])])
else
:
logits
=
tf
.
matmul
(
x
,
self
.
lin_w
)
+
self
.
lin_b
logits
=
self
.
multiclass_dense_layer
(
x
)
lm_loss
=
tf
.
nn
.
sparse_softmax_cross_entropy_with_logits
(
logits
=
logits
,
labels
=
labels
)
...
...
@@ -255,7 +257,7 @@ def predictions(logits):
pred
=
tf
.
cast
(
tf
.
greater
(
tf
.
squeeze
(
logits
,
-
1
),
0.5
),
tf
.
int64
)
# For multi-class classification
else
:
pred
=
tf
.
argmax
(
logits
,
1
)
pred
=
tf
.
argmax
(
logits
,
2
)
return
pred
...
...
@@ -354,10 +356,9 @@ def optimize(loss,
opt
.
ready_for_local_init_op
)
else
:
# Non-sync optimizer
variables_averages_op
=
variable_averages
.
apply
(
tvars
)
apply_gradient_op
=
opt
.
apply_gradients
(
grads_and_vars
,
global_step
)
with
tf
.
control_dependencies
([
apply_gradient_op
,
variables_averages_op
]):
train_op
=
tf
.
no_op
(
name
=
'train_op'
)
with
tf
.
control_dependencies
([
apply_gradient_op
]):
train_op
=
variable_averages
.
apply
(
tvars
)
return
train_op
...
...
research/adversarial_text/train_utils.py
View file @
9f0a567d
...
...
@@ -64,8 +64,8 @@ def run_training(train_op,
sv
=
tf
.
train
.
Supervisor
(
logdir
=
FLAGS
.
train_dir
,
is_chief
=
is_chief
,
save_summaries_secs
=
5
*
6
0
,
save_model_secs
=
5
*
6
0
,
save_summaries_secs
=
3
0
,
save_model_secs
=
3
0
,
local_init_op
=
local_init_op
,
ready_for_local_init_op
=
ready_for_local_init_op
,
global_step
=
global_step
)
...
...
@@ -90,10 +90,9 @@ def run_training(train_op,
global_step_val
=
0
while
not
sv
.
should_stop
()
and
global_step_val
<
FLAGS
.
max_steps
:
global_step_val
=
train_step
(
sess
,
train_op
,
loss
,
global_step
)
sv
.
stop
()
# Final checkpoint
if
is_chief
:
if
is_chief
and
global_step_val
>=
FLAGS
.
max_steps
:
sv
.
saver
.
save
(
sess
,
sv
.
save_path
,
global_step
=
global_step
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment