"...layers/git@developer.sourcefind.cn:OpenDAS/lmdeploy.git" did not exist on "477f2db8ab5770b30f44b0c7910953f44795366e"
Commit 900ea814 authored by Andrew M. Dai
Browse files

Move gen_data and gen_vocab to parent directory to prevent import madness in...

Move gen_data and gen_vocab to parent directory to prevent import madness in the open-source code. Remove unnecessary bazel dependency and documentation in README.

PiperOrigin-RevId: 186638059
parent b555eda9
licenses(["notice"]) # Apache 2.0
exports_files(["LICENSE"])
# Binaries
# ==============================================================================
# Executable target `evaluate`, built from evaluate.py.
# Its only active dependency is the `:graphs` library declared in this BUILD file.
# NOTE(review): the commented-out dep entries below look like placeholders for
# dependencies stripped from the public release — confirm before editing them.
py_binary(
name = "evaluate",
srcs = ["evaluate.py"],
deps = [
":graphs",
# google3 file dep,
# tensorflow internal dep,
],
)
# Executable target `train_classifier`, built from train_classifier.py.
# Active dependencies are the `:graphs` and `:train_utils` libraries declared
# in this BUILD file.
# NOTE(review): the commented-out dep entries below look like placeholders for
# dependencies stripped from the public release — confirm before editing them.
py_binary(
name = "train_classifier",
srcs = ["train_classifier.py"],
deps = [
":graphs",
":train_utils",
# google3 file dep,
# tensorflow internal dep,
],
)
# Executable target `pretrain`, built from pretrain.py.
# Active dependencies are the `:graphs` and `:train_utils` libraries declared
# in this BUILD file.
# NOTE(review): the commented-out dep entries below (including the GPU one)
# look like placeholders for dependencies stripped from the public release —
# confirm before editing them.
py_binary(
name = "pretrain",
srcs = [
"pretrain.py",
],
deps = [
":graphs",
":train_utils",
# google3 file dep,
# tensorflow internal gpu deps
# tensorflow internal dep,
],
)
# Libraries
# ==============================================================================
# Library target `graphs`, wrapping graphs.py.
# It pulls in three sibling libraries from this BUILD file:
# `:adversarial_losses`, `:inputs` and `:layers`.
# NOTE(review): "# tensorflow dep," looks like a placeholder for a stripped
# TensorFlow dependency — confirm.
py_library(
name = "graphs",
srcs = ["graphs.py"],
deps = [
":adversarial_losses",
":inputs",
":layers",
# tensorflow dep,
],
)
# Library target `adversarial_losses`, wrapping adversarial_losses.py.
# It declares no active dependencies; the only dep entry is commented out.
# NOTE(review): "# tensorflow dep," looks like a placeholder for a stripped
# TensorFlow dependency — confirm.
py_library(
name = "adversarial_losses",
srcs = ["adversarial_losses.py"],
deps = [
# tensorflow dep,
],
)
# Library target `inputs`, wrapping inputs.py.
# Its one active dependency is the `data_utils` target of the
# //adversarial_text/data package, referenced by absolute label.
# NOTE(review): "# tensorflow dep," looks like a placeholder for a stripped
# TensorFlow dependency — confirm.
py_library(
name = "inputs",
srcs = ["inputs.py"],
deps = [
# tensorflow dep,
"//adversarial_text/data:data_utils",
],
)
# Library target `layers`, wrapping layers.py.
# It declares no active dependencies; the only dep entry is commented out.
# NOTE(review): "# tensorflow dep," looks like a placeholder for a stripped
# TensorFlow dependency — confirm.
py_library(
name = "layers",
srcs = ["layers.py"],
deps = [
# tensorflow dep,
],
)
# Library target `train_utils`, wrapping train_utils.py.
# It declares no active dependencies; both dep entries are commented out.
# NOTE(review): "# numpy dep," and "# tensorflow dep," look like placeholders
# for stripped NumPy / TensorFlow dependencies — confirm.
py_library(
name = "train_utils",
srcs = ["train_utils.py"],
deps = [
# numpy dep,
# tensorflow dep,
],
)
# Tests
# ==============================================================================
# Test target `graphs_test`, running graphs_test.py.
# It is declared with size "large" (per the Bazel docs this also selects the
# longer default test timeout).
# Active dependencies: the local `:graphs` library and the `data_utils` target
# of the //adversarial_text/data package.
# NOTE(review): "# tensorflow dep," looks like a placeholder for a stripped
# TensorFlow dependency — confirm.
py_test(
name = "graphs_test",
size = "large",
srcs = ["graphs_test.py"],
deps = [
":graphs",
# tensorflow dep,
"//adversarial_text/data:data_utils",
],
)
...@@ -4,7 +4,6 @@ Code for [*Adversarial Training Methods for Semi-Supervised Text Classification* ...@@ -4,7 +4,6 @@ Code for [*Adversarial Training Methods for Semi-Supervised Text Classification*
## Requirements ## Requirements
* Bazel ([install](https://bazel.build/versions/master/docs/install.html))
* TensorFlow >= v1.1 * TensorFlow >= v1.1
## End-to-end IMDB Sentiment Classification ## End-to-end IMDB Sentiment Classification
...@@ -23,7 +22,7 @@ The directory `/tmp/aclImdb` contains the raw IMDB data. ...@@ -23,7 +22,7 @@ The directory `/tmp/aclImdb` contains the raw IMDB data.
``` ```
$ IMDB_DATA_DIR=/tmp/imdb $ IMDB_DATA_DIR=/tmp/imdb
$ bazel run data:gen_vocab -- \ $ python gen_vocab.py \
--output_dir=$IMDB_DATA_DIR \ --output_dir=$IMDB_DATA_DIR \
--dataset=imdb \ --dataset=imdb \
--imdb_input_dir=/tmp/aclImdb \ --imdb_input_dir=/tmp/aclImdb \
...@@ -35,7 +34,7 @@ Vocabulary and frequency files will be generated in `$IMDB_DATA_DIR`. ...@@ -35,7 +34,7 @@ Vocabulary and frequency files will be generated in `$IMDB_DATA_DIR`.
###  Generate training, validation, and test data ###  Generate training, validation, and test data
``` ```
$ bazel run data:gen_data -- \ $ python gen_data.py \
--output_dir=$IMDB_DATA_DIR \ --output_dir=$IMDB_DATA_DIR \
--dataset=imdb \ --dataset=imdb \
--imdb_input_dir=/tmp/aclImdb \ --imdb_input_dir=/tmp/aclImdb \
...@@ -49,7 +48,7 @@ $ bazel run data:gen_data -- \ ...@@ -49,7 +48,7 @@ $ bazel run data:gen_data -- \
``` ```
$ PRETRAIN_DIR=/tmp/models/imdb_pretrain $ PRETRAIN_DIR=/tmp/models/imdb_pretrain
$ bazel run :pretrain -- \ $ python pretrain.py \
--train_dir=$PRETRAIN_DIR \ --train_dir=$PRETRAIN_DIR \
--data_dir=$IMDB_DATA_DIR \ --data_dir=$IMDB_DATA_DIR \
--vocab_size=86934 \ --vocab_size=86934 \
...@@ -77,7 +76,7 @@ training and classification. ...@@ -77,7 +76,7 @@ training and classification.
``` ```
$ TRAIN_DIR=/tmp/models/imdb_classify $ TRAIN_DIR=/tmp/models/imdb_classify
$ bazel run :train_classifier -- \ $ python train_classifier.py \
--train_dir=$TRAIN_DIR \ --train_dir=$TRAIN_DIR \
--pretrained_model_dir=$PRETRAIN_DIR \ --pretrained_model_dir=$PRETRAIN_DIR \
--data_dir=$IMDB_DATA_DIR \ --data_dir=$IMDB_DATA_DIR \
...@@ -102,7 +101,7 @@ $ bazel run :train_classifier -- \ ...@@ -102,7 +101,7 @@ $ bazel run :train_classifier -- \
``` ```
$ EVAL_DIR=/tmp/models/imdb_eval $ EVAL_DIR=/tmp/models/imdb_eval
$ bazel run :evaluate -- \ $ python evaluate.py \
--eval_dir=$EVAL_DIR \ --eval_dir=$EVAL_DIR \
--checkpoint_dir=$TRAIN_DIR \ --checkpoint_dir=$TRAIN_DIR \
--eval_data=test \ --eval_data=test \
...@@ -145,8 +144,8 @@ Flags particular to each job are defined in the main binary files. ...@@ -145,8 +144,8 @@ Flags particular to each job are defined in the main binary files.
### Data Generation ### Data Generation
* Vocabulary generation: [`gen_vocab.py`](https://github.com/tensorflow/models/tree/master/research/adversarial_text/data/gen_vocab.py) * Vocabulary generation: [`gen_vocab.py`](https://github.com/tensorflow/models/tree/master/research/adversarial_text/gen_vocab.py)
* Data generation: [`gen_data.py`](https://github.com/tensorflow/models/tree/master/research/adversarial_text/data/gen_data.py) * Data generation: [`gen_data.py`](https://github.com/tensorflow/models/tree/master/research/adversarial_text/gen_data.py)
Command-line flags defined in [`document_generators.py`](https://github.com/tensorflow/models/tree/master/research/adversarial_text/data/document_generators.py) Command-line flags defined in [`document_generators.py`](https://github.com/tensorflow/models/tree/master/research/adversarial_text/data/document_generators.py)
control which dataset is processed and how. control which dataset is processed and how.
...@@ -154,4 +153,4 @@ control which dataset is processed and how. ...@@ -154,4 +153,4 @@ control which dataset is processed and how.
## Contact for Issues ## Contact for Issues
* Ryan Sepassi, @rsepassi * Ryan Sepassi, @rsepassi
* Andrew M. Dai, @a-dai * Andrew M. Dai, @a-dai <adai@google.com>
licenses(["notice"]) # Apache 2.0
# Package-wide default: targets in this BUILD file that do not set their own
# `visibility` are visible to //adversarial_text and all of its subpackages.
package(
default_visibility = [
"//adversarial_text:__subpackages__",
],
)
# Executable target `gen_vocab`, built from gen_vocab.py.
# Active dependencies are the `:data_utils` and `:document_generators`
# libraries declared in this BUILD file.
# NOTE(review): "# tensorflow dep," looks like a placeholder for a stripped
# TensorFlow dependency — confirm.
py_binary(
name = "gen_vocab",
srcs = ["gen_vocab.py"],
deps = [
":data_utils",
":document_generators",
# tensorflow dep,
],
)
# Executable target `gen_data`, built from gen_data.py.
# Active dependencies are the `:data_utils` and `:document_generators`
# libraries declared in this BUILD file.
# NOTE(review): "# tensorflow dep," looks like a placeholder for a stripped
# TensorFlow dependency — confirm.
py_binary(
name = "gen_data",
srcs = ["gen_data.py"],
deps = [
":data_utils",
":document_generators",
# tensorflow dep,
],
)
# Library target `document_generators`, wrapping document_generators.py.
# It declares no active dependencies; the only dep entry is commented out.
# NOTE(review): "# tensorflow dep," looks like a placeholder for a stripped
# TensorFlow dependency — confirm.
py_library(
name = "document_generators",
srcs = ["document_generators.py"],
deps = [
# tensorflow dep,
],
)
# Library target `data_utils`, wrapping data_utils.py.
# It declares no active dependencies; the only dep entry is commented out.
# NOTE(review): "# tensorflow dep," looks like a placeholder for a stripped
# TensorFlow dependency — confirm.
py_library(
name = "data_utils",
srcs = ["data_utils.py"],
deps = [
# tensorflow dep,
],
)
# Test target `data_utils_test`, running data_utils_test.py against the
# `:data_utils` library declared in this BUILD file.
# NOTE(review): "# tensorflow dep," looks like a placeholder for a stripped
# TensorFlow dependency — confirm.
py_test(
name = "data_utils_test",
srcs = ["data_utils_test.py"],
deps = [
":data_utils",
# tensorflow dep,
],
)
...@@ -21,7 +21,7 @@ from __future__ import print_function ...@@ -21,7 +21,7 @@ from __future__ import print_function
import tensorflow as tf import tensorflow as tf
from adversarial_text.data import data_utils from data import data_utils
data = data_utils data = data_utils
......
...@@ -26,7 +26,7 @@ import random ...@@ -26,7 +26,7 @@ import random
import tensorflow as tf import tensorflow as tf
import data_utils from data import data_utils
flags = tf.app.flags flags = tf.app.flags
FLAGS = flags.FLAGS FLAGS = flags.FLAGS
......
...@@ -34,8 +34,8 @@ import string ...@@ -34,8 +34,8 @@ import string
import tensorflow as tf import tensorflow as tf
import data_utils from data import data_utils
import document_generators from data import document_generators
data = data_utils data = data_utils
flags = tf.app.flags flags = tf.app.flags
......
...@@ -23,8 +23,8 @@ from collections import defaultdict ...@@ -23,8 +23,8 @@ from collections import defaultdict
import tensorflow as tf import tensorflow as tf
import data_utils from data import data_utils
import document_generators from data import document_generators
flags = tf.app.flags flags = tf.app.flags
FLAGS = flags.FLAGS FLAGS = flags.FLAGS
......
...@@ -29,8 +29,8 @@ import tempfile ...@@ -29,8 +29,8 @@ import tempfile
import tensorflow as tf import tensorflow as tf
from adversarial_text import graphs import graphs
from adversarial_text.data import data_utils from data import data_utils
flags = tf.app.flags flags = tf.app.flags
FLAGS = flags.FLAGS FLAGS = flags.FLAGS
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment