"src/array/vscode:/vscode.git/clone" did not exist on "1506560e1ecc6a2c56b51e23766632a9326aeeca"
Unverified commit d903d5ed, authored by Andrew M Dai, committed by GitHub
Browse files

Merge pull request #3414 from a-dai/master

Fix github issue #3269 where the accuracy is wrongly underestimated for binary classification and build issue #2784
parents d6d08682 900ea814
licenses(["notice"]) # Apache 2.0
exports_files(["LICENSE"])
# Binaries
# ==============================================================================
py_binary(
name = "evaluate",
srcs = ["evaluate.py"],
deps = [
":graphs",
# google3 file dep,
# tensorflow internal dep,
],
)
py_binary(
name = "train_classifier",
srcs = ["train_classifier.py"],
deps = [
":graphs",
":train_utils",
# google3 file dep,
# tensorflow internal dep,
],
)
py_binary(
name = "pretrain",
srcs = [
"pretrain.py",
],
deps = [
":graphs",
":train_utils",
# google3 file dep,
# tensorflow internal gpu deps
# tensorflow internal dep,
],
)
# Libraries
# ==============================================================================
py_library(
name = "graphs",
srcs = ["graphs.py"],
deps = [
":adversarial_losses",
":inputs",
":layers",
# tensorflow dep,
],
)
py_library(
name = "adversarial_losses",
srcs = ["adversarial_losses.py"],
deps = [
# tensorflow dep,
],
)
py_library(
name = "inputs",
srcs = ["inputs.py"],
deps = [
# tensorflow dep,
"//adversarial_text/data:data_utils",
],
)
py_library(
name = "layers",
srcs = ["layers.py"],
deps = [
# tensorflow dep,
],
)
py_library(
name = "train_utils",
srcs = ["train_utils.py"],
deps = [
# numpy dep,
# tensorflow dep,
],
)
# Tests
# ==============================================================================
py_test(
name = "graphs_test",
size = "large",
srcs = ["graphs_test.py"],
deps = [
":graphs",
# tensorflow dep,
"//adversarial_text/data:data_utils",
],
)
......@@ -4,7 +4,6 @@ Code for [*Adversarial Training Methods for Semi-Supervised Text Classification*
## Requirements
* Bazel ([install](https://bazel.build/versions/master/docs/install.html))
* TensorFlow >= v1.1
## End-to-end IMDB Sentiment Classification
......@@ -23,7 +22,7 @@ The directory `/tmp/aclImdb` contains the raw IMDB data.
```
$ IMDB_DATA_DIR=/tmp/imdb
$ bazel run data:gen_vocab -- \
$ python gen_vocab.py -- \
--output_dir=$IMDB_DATA_DIR \
--dataset=imdb \
--imdb_input_dir=/tmp/aclImdb \
......@@ -35,7 +34,7 @@ Vocabulary and frequency files will be generated in `$IMDB_DATA_DIR`.
###  Generate training, validation, and test data
```
$ bazel run data:gen_data -- \
$ python gen_data.py -- \
--output_dir=$IMDB_DATA_DIR \
--dataset=imdb \
--imdb_input_dir=/tmp/aclImdb \
......@@ -49,7 +48,7 @@ $ bazel run data:gen_data -- \
```
$ PRETRAIN_DIR=/tmp/models/imdb_pretrain
$ bazel run :pretrain -- \
$ python pretrain.py -- \
--train_dir=$PRETRAIN_DIR \
--data_dir=$IMDB_DATA_DIR \
--vocab_size=86934 \
......@@ -77,7 +76,7 @@ training and classification.
```
$ TRAIN_DIR=/tmp/models/imdb_classify
$ bazel run :train_classifier -- \
$ python train_classifier.py -- \
--train_dir=$TRAIN_DIR \
--pretrained_model_dir=$PRETRAIN_DIR \
--data_dir=$IMDB_DATA_DIR \
......@@ -102,7 +101,7 @@ $ bazel run :train_classifier -- \
```
$ EVAL_DIR=/tmp/models/imdb_eval
$ bazel run :evaluate -- \
$ python evaluate.py -- \
--eval_dir=$EVAL_DIR \
--checkpoint_dir=$TRAIN_DIR \
--eval_data=test \
......@@ -145,8 +144,8 @@ Flags particular to each job are defined in the main binary files.
### Data Generation
* Vocabulary generation: [`gen_vocab.py`](https://github.com/tensorflow/models/tree/master/research/adversarial_text/data/gen_vocab.py)
* Data generation: [`gen_data.py`](https://github.com/tensorflow/models/tree/master/research/adversarial_text/data/gen_data.py)
* Vocabulary generation: [`gen_vocab.py`](https://github.com/tensorflow/models/tree/master/research/adversarial_text/gen_vocab.py)
* Data generation: [`gen_data.py`](https://github.com/tensorflow/models/tree/master/research/adversarial_text/gen_data.py)
Command-line flags defined in [`document_generators.py`](https://github.com/tensorflow/models/tree/master/research/adversarial_text/data/document_generators.py)
control which dataset is processed and how.
......@@ -154,4 +153,4 @@ control which dataset is processed and how.
## Contact for Issues
* Ryan Sepassi, @rsepassi
* Andrew M. Dai, @a-dai
* Andrew M. Dai, @a-dai <adai@google.com>
licenses(["notice"]) # Apache 2.0
package(
default_visibility = [
"//adversarial_text:__subpackages__",
],
)
py_binary(
name = "gen_vocab",
srcs = ["gen_vocab.py"],
deps = [
":data_utils",
":document_generators",
# tensorflow dep,
],
)
py_binary(
name = "gen_data",
srcs = ["gen_data.py"],
deps = [
":data_utils",
":document_generators",
# tensorflow dep,
],
)
py_library(
name = "document_generators",
srcs = ["document_generators.py"],
deps = [
# tensorflow dep,
],
)
py_library(
name = "data_utils",
srcs = ["data_utils.py"],
deps = [
# tensorflow dep,
],
)
py_test(
name = "data_utils_test",
srcs = ["data_utils_test.py"],
deps = [
":data_utils",
# tensorflow dep,
],
)
......@@ -21,7 +21,7 @@ from __future__ import print_function
import tensorflow as tf
from adversarial_text.data import data_utils
from data import data_utils
data = data_utils
......
......@@ -26,7 +26,7 @@ import random
import tensorflow as tf
from adversarial_text.data import data_utils
from data import data_utils
flags = tf.app.flags
FLAGS = flags.FLAGS
......
......@@ -25,7 +25,7 @@ import time
import tensorflow as tf
from adversarial_text import graphs
import graphs
flags = tf.app.flags
FLAGS = flags.FLAGS
......
......@@ -34,8 +34,8 @@ import string
import tensorflow as tf
from adversarial_text.data import data_utils
from adversarial_text.data import document_generators
from data import data_utils
from data import document_generators
data = data_utils
flags = tf.app.flags
......
......@@ -23,8 +23,8 @@ from collections import defaultdict
import tensorflow as tf
from adversarial_text.data import data_utils
from adversarial_text.data import document_generators
from data import data_utils
from data import document_generators
flags = tf.app.flags
FLAGS = flags.FLAGS
......
......@@ -24,9 +24,9 @@ import os
import tensorflow as tf
from adversarial_text import adversarial_losses as adv_lib
from adversarial_text import inputs as inputs_lib
from adversarial_text import layers as layers_lib
import adversarial_losses as adv_lib
import inputs as inputs_lib
import layers as layers_lib
flags = tf.app.flags
FLAGS = flags.FLAGS
......
......@@ -29,8 +29,8 @@ import tempfile
import tensorflow as tf
from adversarial_text import graphs
from adversarial_text.data import data_utils
import graphs
from data import data_utils
flags = tf.app.flags
FLAGS = flags.FLAGS
......
......@@ -24,7 +24,7 @@ import os
import tensorflow as tf
from adversarial_text.data import data_utils
from data import data_utils
class VatxtInput(object):
......
......@@ -254,7 +254,7 @@ def predictions(logits):
with tf.name_scope('predictions'):
# For binary classification
if inner_dim == 1:
pred = tf.cast(tf.greater(tf.squeeze(logits, -1), 0.5), tf.int64)
pred = tf.cast(tf.greater(tf.squeeze(logits, -1), 0.), tf.int64)
# For multi-class classification
else:
pred = tf.argmax(logits, 2)
......
......@@ -27,8 +27,8 @@ from __future__ import print_function
import tensorflow as tf
from adversarial_text import graphs
from adversarial_text import train_utils
import graphs
import train_utils
FLAGS = tf.app.flags.FLAGS
......
......@@ -35,8 +35,8 @@ from __future__ import print_function
import tensorflow as tf
from adversarial_text import graphs
from adversarial_text import train_utils
import graphs
import train_utils
flags = tf.app.flags
FLAGS = flags.FLAGS
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment