Unverified Commit d903d5ed authored by Andrew M Dai's avatar Andrew M Dai Committed by GitHub
Browse files

Merge pull request #3414 from a-dai/master

Fix github issue #3269 where the accuracy is wrongly underestimated for binary classification and build issue #2784
parents d6d08682 900ea814
licenses(["notice"]) # Apache 2.0
exports_files(["LICENSE"])
# Binaries
# ==============================================================================
py_binary(
name = "evaluate",
srcs = ["evaluate.py"],
deps = [
":graphs",
# google3 file dep,
# tensorflow internal dep,
],
)
py_binary(
name = "train_classifier",
srcs = ["train_classifier.py"],
deps = [
":graphs",
":train_utils",
# google3 file dep,
# tensorflow internal dep,
],
)
py_binary(
name = "pretrain",
srcs = [
"pretrain.py",
],
deps = [
":graphs",
":train_utils",
# google3 file dep,
# tensorflow internal gpu deps
# tensorflow internal dep,
],
)
# Libraries
# ==============================================================================
py_library(
name = "graphs",
srcs = ["graphs.py"],
deps = [
":adversarial_losses",
":inputs",
":layers",
# tensorflow dep,
],
)
py_library(
name = "adversarial_losses",
srcs = ["adversarial_losses.py"],
deps = [
# tensorflow dep,
],
)
py_library(
name = "inputs",
srcs = ["inputs.py"],
deps = [
# tensorflow dep,
"//adversarial_text/data:data_utils",
],
)
py_library(
name = "layers",
srcs = ["layers.py"],
deps = [
# tensorflow dep,
],
)
py_library(
name = "train_utils",
srcs = ["train_utils.py"],
deps = [
# numpy dep,
# tensorflow dep,
],
)
# Tests
# ==============================================================================
py_test(
name = "graphs_test",
size = "large",
srcs = ["graphs_test.py"],
deps = [
":graphs",
# tensorflow dep,
"//adversarial_text/data:data_utils",
],
)
...@@ -4,7 +4,6 @@ Code for [*Adversarial Training Methods for Semi-Supervised Text Classification* ...@@ -4,7 +4,6 @@ Code for [*Adversarial Training Methods for Semi-Supervised Text Classification*
## Requirements ## Requirements
* Bazel ([install](https://bazel.build/versions/master/docs/install.html))
* TensorFlow >= v1.1 * TensorFlow >= v1.1
## End-to-end IMDB Sentiment Classification ## End-to-end IMDB Sentiment Classification
...@@ -23,7 +22,7 @@ The directory `/tmp/aclImdb` contains the raw IMDB data. ...@@ -23,7 +22,7 @@ The directory `/tmp/aclImdb` contains the raw IMDB data.
``` ```
$ IMDB_DATA_DIR=/tmp/imdb $ IMDB_DATA_DIR=/tmp/imdb
$ bazel run data:gen_vocab -- \ $ python gen_vocab.py -- \
--output_dir=$IMDB_DATA_DIR \ --output_dir=$IMDB_DATA_DIR \
--dataset=imdb \ --dataset=imdb \
--imdb_input_dir=/tmp/aclImdb \ --imdb_input_dir=/tmp/aclImdb \
...@@ -35,7 +34,7 @@ Vocabulary and frequency files will be generated in `$IMDB_DATA_DIR`. ...@@ -35,7 +34,7 @@ Vocabulary and frequency files will be generated in `$IMDB_DATA_DIR`.
###  Generate training, validation, and test data ###  Generate training, validation, and test data
``` ```
$ bazel run data:gen_data -- \ $ python gen_data.py -- \
--output_dir=$IMDB_DATA_DIR \ --output_dir=$IMDB_DATA_DIR \
--dataset=imdb \ --dataset=imdb \
--imdb_input_dir=/tmp/aclImdb \ --imdb_input_dir=/tmp/aclImdb \
...@@ -49,7 +48,7 @@ $ bazel run data:gen_data -- \ ...@@ -49,7 +48,7 @@ $ bazel run data:gen_data -- \
``` ```
$ PRETRAIN_DIR=/tmp/models/imdb_pretrain $ PRETRAIN_DIR=/tmp/models/imdb_pretrain
$ bazel run :pretrain -- \ $ python pretrain.py -- \
--train_dir=$PRETRAIN_DIR \ --train_dir=$PRETRAIN_DIR \
--data_dir=$IMDB_DATA_DIR \ --data_dir=$IMDB_DATA_DIR \
--vocab_size=86934 \ --vocab_size=86934 \
...@@ -77,7 +76,7 @@ training and classification. ...@@ -77,7 +76,7 @@ training and classification.
``` ```
$ TRAIN_DIR=/tmp/models/imdb_classify $ TRAIN_DIR=/tmp/models/imdb_classify
$ bazel run :train_classifier -- \ $ python train_classifier.py -- \
--train_dir=$TRAIN_DIR \ --train_dir=$TRAIN_DIR \
--pretrained_model_dir=$PRETRAIN_DIR \ --pretrained_model_dir=$PRETRAIN_DIR \
--data_dir=$IMDB_DATA_DIR \ --data_dir=$IMDB_DATA_DIR \
...@@ -102,7 +101,7 @@ $ bazel run :train_classifier -- \ ...@@ -102,7 +101,7 @@ $ bazel run :train_classifier -- \
``` ```
$ EVAL_DIR=/tmp/models/imdb_eval $ EVAL_DIR=/tmp/models/imdb_eval
$ bazel run :evaluate -- \ $ python evaluate.py -- \
--eval_dir=$EVAL_DIR \ --eval_dir=$EVAL_DIR \
--checkpoint_dir=$TRAIN_DIR \ --checkpoint_dir=$TRAIN_DIR \
--eval_data=test \ --eval_data=test \
...@@ -145,8 +144,8 @@ Flags particular to each job are defined in the main binary files. ...@@ -145,8 +144,8 @@ Flags particular to each job are defined in the main binary files.
### Data Generation ### Data Generation
* Vocabulary generation: [`gen_vocab.py`](https://github.com/tensorflow/models/tree/master/research/adversarial_text/data/gen_vocab.py) * Vocabulary generation: [`gen_vocab.py`](https://github.com/tensorflow/models/tree/master/research/adversarial_text/gen_vocab.py)
* Data generation: [`gen_data.py`](https://github.com/tensorflow/models/tree/master/research/adversarial_text/data/gen_data.py) * Data generation: [`gen_data.py`](https://github.com/tensorflow/models/tree/master/research/adversarial_text/gen_data.py)
Command-line flags defined in [`document_generators.py`](https://github.com/tensorflow/models/tree/master/research/adversarial_text/data/document_generators.py) Command-line flags defined in [`document_generators.py`](https://github.com/tensorflow/models/tree/master/research/adversarial_text/data/document_generators.py)
control which dataset is processed and how. control which dataset is processed and how.
...@@ -154,4 +153,4 @@ control which dataset is processed and how. ...@@ -154,4 +153,4 @@ control which dataset is processed and how.
## Contact for Issues ## Contact for Issues
* Ryan Sepassi, @rsepassi * Ryan Sepassi, @rsepassi
* Andrew M. Dai, @a-dai * Andrew M. Dai, @a-dai <adai@google.com>
licenses(["notice"]) # Apache 2.0
package(
default_visibility = [
"//adversarial_text:__subpackages__",
],
)
py_binary(
name = "gen_vocab",
srcs = ["gen_vocab.py"],
deps = [
":data_utils",
":document_generators",
# tensorflow dep,
],
)
py_binary(
name = "gen_data",
srcs = ["gen_data.py"],
deps = [
":data_utils",
":document_generators",
# tensorflow dep,
],
)
py_library(
name = "document_generators",
srcs = ["document_generators.py"],
deps = [
# tensorflow dep,
],
)
py_library(
name = "data_utils",
srcs = ["data_utils.py"],
deps = [
# tensorflow dep,
],
)
py_test(
name = "data_utils_test",
srcs = ["data_utils_test.py"],
deps = [
":data_utils",
# tensorflow dep,
],
)
...@@ -21,7 +21,7 @@ from __future__ import print_function ...@@ -21,7 +21,7 @@ from __future__ import print_function
import tensorflow as tf import tensorflow as tf
from adversarial_text.data import data_utils from data import data_utils
data = data_utils data = data_utils
......
...@@ -26,7 +26,7 @@ import random ...@@ -26,7 +26,7 @@ import random
import tensorflow as tf import tensorflow as tf
from adversarial_text.data import data_utils from data import data_utils
flags = tf.app.flags flags = tf.app.flags
FLAGS = flags.FLAGS FLAGS = flags.FLAGS
......
...@@ -25,7 +25,7 @@ import time ...@@ -25,7 +25,7 @@ import time
import tensorflow as tf import tensorflow as tf
from adversarial_text import graphs import graphs
flags = tf.app.flags flags = tf.app.flags
FLAGS = flags.FLAGS FLAGS = flags.FLAGS
......
...@@ -34,8 +34,8 @@ import string ...@@ -34,8 +34,8 @@ import string
import tensorflow as tf import tensorflow as tf
from adversarial_text.data import data_utils from data import data_utils
from adversarial_text.data import document_generators from data import document_generators
data = data_utils data = data_utils
flags = tf.app.flags flags = tf.app.flags
......
...@@ -23,8 +23,8 @@ from collections import defaultdict ...@@ -23,8 +23,8 @@ from collections import defaultdict
import tensorflow as tf import tensorflow as tf
from adversarial_text.data import data_utils from data import data_utils
from adversarial_text.data import document_generators from data import document_generators
flags = tf.app.flags flags = tf.app.flags
FLAGS = flags.FLAGS FLAGS = flags.FLAGS
......
...@@ -24,9 +24,9 @@ import os ...@@ -24,9 +24,9 @@ import os
import tensorflow as tf import tensorflow as tf
from adversarial_text import adversarial_losses as adv_lib import adversarial_losses as adv_lib
from adversarial_text import inputs as inputs_lib import inputs as inputs_lib
from adversarial_text import layers as layers_lib import layers as layers_lib
flags = tf.app.flags flags = tf.app.flags
FLAGS = flags.FLAGS FLAGS = flags.FLAGS
......
...@@ -29,8 +29,8 @@ import tempfile ...@@ -29,8 +29,8 @@ import tempfile
import tensorflow as tf import tensorflow as tf
from adversarial_text import graphs import graphs
from adversarial_text.data import data_utils from data import data_utils
flags = tf.app.flags flags = tf.app.flags
FLAGS = flags.FLAGS FLAGS = flags.FLAGS
......
...@@ -24,7 +24,7 @@ import os ...@@ -24,7 +24,7 @@ import os
import tensorflow as tf import tensorflow as tf
from adversarial_text.data import data_utils from data import data_utils
class VatxtInput(object): class VatxtInput(object):
......
...@@ -254,7 +254,7 @@ def predictions(logits): ...@@ -254,7 +254,7 @@ def predictions(logits):
with tf.name_scope('predictions'): with tf.name_scope('predictions'):
# For binary classification # For binary classification
if inner_dim == 1: if inner_dim == 1:
pred = tf.cast(tf.greater(tf.squeeze(logits, -1), 0.5), tf.int64) pred = tf.cast(tf.greater(tf.squeeze(logits, -1), 0.), tf.int64)
# For multi-class classification # For multi-class classification
else: else:
pred = tf.argmax(logits, 2) pred = tf.argmax(logits, 2)
......
...@@ -27,8 +27,8 @@ from __future__ import print_function ...@@ -27,8 +27,8 @@ from __future__ import print_function
import tensorflow as tf import tensorflow as tf
from adversarial_text import graphs import graphs
from adversarial_text import train_utils import train_utils
FLAGS = tf.app.flags.FLAGS FLAGS = tf.app.flags.FLAGS
......
...@@ -35,8 +35,8 @@ from __future__ import print_function ...@@ -35,8 +35,8 @@ from __future__ import print_function
import tensorflow as tf import tensorflow as tf
from adversarial_text import graphs import graphs
from adversarial_text import train_utils import train_utils
flags = tf.app.flags flags = tf.app.flags
FLAGS = flags.FLAGS FLAGS = flags.FLAGS
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment