Unverified commit 21f794fd, authored by Katherine Wu, committed by GitHub

Export Transformer saved model, and add vocab file flag. (#4281)

parent 2d5e95a3
@@ -20,6 +20,8 @@ The model also applies embeddings on the input and output tokens, and adds a con
 * [Translate using the model](#translate-using-the-model)
 * [Compute official BLEU score](#compute-official-bleu-score)
 * [TPU](#tpu)
+* [Export trained model](#export-trained-model)
+* [Example translation](#example-translation)
 * [Implementation overview](#implementation-overview)
   * [Model Definition](#model-definition)
   * [Model Estimator](#model-estimator)
@@ -42,24 +44,26 @@ cd /path/to/models/official/transformer
 PARAM_SET=big
 DATA_DIR=$HOME/transformer/data
 MODEL_DIR=$HOME/transformer/model_$PARAM_SET
+VOCAB_FILE=$DATA_DIR/vocab.ende.32768
 
 # Download training/evaluation datasets
 python data_download.py --data_dir=$DATA_DIR
 
 # Train the model for 10 epochs, and evaluate after every epoch.
 python transformer_main.py --data_dir=$DATA_DIR --model_dir=$MODEL_DIR \
-    --param_set=$PARAM_SET --bleu_source=test_data/newstest2014.en --bleu_ref=test_data/newstest2014.de
+    --vocab_file=$VOCAB_FILE --param_set=$PARAM_SET \
+    --bleu_source=test_data/newstest2014.en --bleu_ref=test_data/newstest2014.de
 
 # Run during training in a separate process to get continuous updates,
 # or after training is complete.
 tensorboard --logdir=$MODEL_DIR
 
 # Translate some text using the trained model
-python translate.py --data_dir=$DATA_DIR --model_dir=$MODEL_DIR \
+python translate.py --model_dir=$MODEL_DIR --vocab_file=$VOCAB_FILE \
     --param_set=$PARAM_SET --text="hello world"
 
 # Compute model's BLEU score using the newstest2014 dataset.
-python translate.py --data_dir=$DATA_DIR --model_dir=$MODEL_DIR \
+python translate.py --model_dir=$MODEL_DIR --vocab_file=$VOCAB_FILE \
     --param_set=$PARAM_SET --file=test_data/newstest2014.en --file_out=translation.en
 python compute_bleu.py --translation=translation.en --reference=test_data/newstest2014.de
 ```
@@ -99,6 +103,7 @@ big | 28.9
    PARAM_SET=big
    DATA_DIR=$HOME/transformer/data
    MODEL_DIR=$HOME/transformer/model_$PARAM_SET
+   VOCAB_FILE=$DATA_DIR/vocab.ende.32768
    ```
 
 1. ### Download and preprocess datasets
@@ -122,12 +127,14 @@ big | 28.9
    Command to run:
    ```
-   python transformer_main.py --data_dir=$DATA_DIR --model_dir=$MODEL_DIR --param_set=$PARAM_SET
+   python transformer_main.py --data_dir=$DATA_DIR --model_dir=$MODEL_DIR \
+       --vocab_file=$VOCAB_FILE --param_set=$PARAM_SET
    ```
 
    Arguments:
    * `--data_dir`: This should be set to the same directory given to the `data_download`'s `data_dir` argument.
    * `--model_dir`: Directory to save Transformer model training checkpoints.
+   * `--vocab_file`: Path to subtoken vocabulary file. If data_download.py was used, you may find the file in `data_dir`.
    * `--param_set`: Parameter set to use when creating and training the model. Options are `base` and `big` (default).
    * Use the `--help` or `-h` flag to get a full list of possible arguments.
@@ -168,12 +175,13 @@ big | 28.9
    Command to run:
    ```
-   python translate.py --data_dir=$DATA_DIR --model_dir=$MODEL_DIR --param_set=$PARAM_SET --text="hello world"
+   python translate.py --model_dir=$MODEL_DIR --vocab_file=$VOCAB_FILE \
+       --param_set=$PARAM_SET --text="hello world"
    ```
 
    Arguments for initializing the Subtokenizer and trained model:
-   * `--data_dir`: Used to locate the vocabulary file to create a Subtokenizer, which encodes the input and decodes the model output.
    * `--model_dir` and `--param_set`: These parameters are used to rebuild the trained model.
+   * `--vocab_file`: Path to subtoken vocabulary file. If data_download.py was used, you may find the file in `data_dir`.
 
    Arguments for specifying what to translate:
    * `--text`: Text to translate
@@ -182,8 +190,8 @@ big | 28.9
    To translate the newstest2014 data, run:
    ```
-   python translate.py --data_dir=$DATA_DIR --model_dir=$MODEL_DIR \
-       --param_set=$PARAM_SET --file=test_data/newstest2014.en --file_out=translation.en
+   python translate.py --model_dir=$MODEL_DIR --vocab_file=$VOCAB_FILE \
+       --param_set=$PARAM_SET --file=test_data/newstest2014.en --file_out=translation.en
    ```
 
    Translating the file takes around 15 minutes on a GTX1080, or 5 minutes on a P100.
@@ -200,11 +208,103 @@ big | 28.9
    * `--translation`: Path to file containing generated translations.
    * `--reference`: Path to file containing reference translations.
    * Use the `--help` or `-h` flag to get a full list of possible arguments.
 
 5. ### TPU
    TPU support for this version of Transformer is experimental. Currently it is present for
    demonstration purposes only, but will be optimized in the coming weeks.
 
+## Export trained model
+To export the model in the TensorFlow [SavedModel](https://www.tensorflow.org/programmers_guide/saved_model) format, use the argument `--export_dir` when running `transformer_main.py`. A folder will be created in that directory, named with the export timestamp (e.g. `$EXPORT_DIR/1526427396`).
+```
+EXPORT_DIR=$HOME/transformer/saved_model
+python transformer_main.py --data_dir=$DATA_DIR --model_dir=$MODEL_DIR \
+    --vocab_file=$VOCAB_FILE --param_set=$PARAM_SET --export_dir=$EXPORT_DIR
+```
+
+To inspect the SavedModel, use `saved_model_cli`:
+```
+SAVED_MODEL_DIR=$EXPORT_DIR/{TIMESTAMP}  # replace {TIMESTAMP} with the name of the folder created
+saved_model_cli show --dir=$SAVED_MODEL_DIR --all
+```
+### Example translation
+Let's translate **"hello world!"**, **"goodbye world."**, and **"Would you like some pie?"**.
+
+The SignatureDef for "translate" is:
+```
+signature_def['translate']:
+  The given SavedModel SignatureDef contains the following input(s):
+    inputs['input'] tensor_info:
+        dtype: DT_INT64
+        shape: (-1, -1)
+        name: Placeholder:0
+  The given SavedModel SignatureDef contains the following output(s):
+    outputs['outputs'] tensor_info:
+        dtype: DT_INT32
+        shape: (-1, -1)
+        name: model/Transformer/strided_slice_19:0
+    outputs['scores'] tensor_info:
+        dtype: DT_FLOAT
+        shape: (-1)
+        name: model/Transformer/strided_slice_20:0
+```
+
+Follow the steps below to use the translate signature def:
+
+1. #### Encode the inputs to integer arrays.
+   This can be done using `utils.tokenizer.Subtokenizer`, and the vocab file in the SavedModel assets (`$SAVED_MODEL_DIR/assets.extra/vocab.txt`).
+   ```
+   from official.transformer.utils.tokenizer import Subtokenizer
+   s = Subtokenizer(PATH_TO_VOCAB_FILE)
+   print(s.encode("hello world!", add_eos=True))
+   ```
+   The encoded inputs are:
+   * `"hello world!" = [6170, 3731, 178, 207, 1]`
+   * `"goodbye world." = [15431, 13966, 36, 178, 3, 1]`
+   * `"Would you like some pie?" = [9092, 72, 155, 202, 19851, 102, 1]`
+
+2. #### Run `saved_model_cli` to obtain the predicted translations.
+   The encoded inputs should be padded so that they are the same length. The padding token is `0`.
+   ```
+   ENCODED_INPUTS="[[26228, 145, 178, 1, 0, 0, 0], \
+       [15431, 13966, 36, 178, 3, 1, 0], \
+       [9092, 72, 155, 202, 19851, 102, 1]]"
+   ```
+   Now, use the `run` command with `saved_model_cli` to get the outputs.
+   ```
+   saved_model_cli run --dir=$SAVED_MODEL_DIR --tag_set=serve --signature_def=translate \
+       --input_expr="input=$ENCODED_INPUTS"
+   ```
+   The outputs will look similar to:
+   ```
+   Result for output key outputs:
+   [[18744   145   297     1     0     0     0     0     0     0     0     0     0     0]
+    [ 5450  4642    21    11   297     3     1     0     0     0     0     0     0     0]
+    [25940    22    66   103 21713    31   102     1     0     0     0     0     0     0]]
+   Result for output key scores:
+   [-1.5493642 -1.4032784 -3.252089 ]
+   ```
+
+3. #### Decode the outputs to strings.
+   Use the `Subtokenizer` and vocab file as described in step 1 to decode the output integer arrays.
+   ```
+   from official.transformer.utils.tokenizer import Subtokenizer
+   s = Subtokenizer(PATH_TO_VOCAB_FILE)
+   print(s.decode([18744, 145, 297, 1]))
+   ```
+   The decoded outputs from above are:
+   * `[18744, 145, 297, 1] = "Hallo Welt<EOS>"`
+   * `[5450, 4642, 21, 11, 297, 3, 1] = "Abschied von der Welt.<EOS>"`
+   * `[25940, 22, 66, 103, 21713, 31, 102, 1] = "Möchten Sie einen Kuchen?<EOS>"`
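The padding in step 2 and the trimming needed before decoding in step 3 can be sketched in plain Python. These two helpers (`pad_batch`, `trim_output`) are not part of the repository; they only assume pad id `0` and EOS id `1` as shown above:

```python
def pad_batch(sequences, pad_id=0):
    """Right-pad integer token sequences to the length of the longest one."""
    max_len = max(len(seq) for seq in sequences)
    return [seq + [pad_id] * (max_len - len(seq)) for seq in sequences]

def trim_output(ids, eos_id=1, pad_id=0):
    """Drop the EOS token, padding, and anything after EOS, leaving only the
    ids that should be passed to Subtokenizer.decode."""
    trimmed = []
    for token_id in ids:
        if token_id in (eos_id, pad_id):
            break
        trimmed.append(token_id)
    return trimmed

# Pad a batch of encoded inputs to a uniform length.
print(pad_batch([[15431, 13966, 36, 178, 3, 1],
                 [9092, 72, 155, 202, 19851, 102, 1]]))
# Trim the first output row from step 2 before decoding it.
print(trim_output([18744, 145, 297, 1, 0, 0, 0]))  # [18744, 145, 297]
```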
 ## Implementation overview
 
 A brief look at each component in the code:
@@ -35,7 +35,6 @@ import tensorflow as tf
 from official.transformer import compute_bleu
 from official.transformer import translate
-from official.transformer.data_download import VOCAB_FILE
 from official.transformer.model import model_params
 from official.transformer.model import transformer
 from official.transformer.utils import dataset
@@ -43,6 +42,7 @@ from official.transformer.utils import metrics
 from official.transformer.utils import schedule
 from official.transformer.utils import tokenizer
 from official.utils.accelerator import tpu as tpu_util
+from official.utils.export import export
 from official.utils.flags import core as flags_core
 from official.utils.logs import hooks_helper
 from official.utils.logs import logger
@@ -55,8 +55,8 @@ PARAMS_MAP = {
     "big": model_params.BIG_PARAMS,
 }
 
 DEFAULT_TRAIN_EPOCHS = 10
-BLEU_DIR = "bleu"
 INF = int(1e9)
+BLEU_DIR = "bleu"
 
 # Dictionary containing tensors that are logged by the logging hooks. Each item
 # maps a string to the tensor name.
@@ -82,7 +82,10 @@ def model_fn(features, labels, mode, params):
         raise NotImplementedError("Prediction is not yet supported on TPUs.")
       return tf.estimator.EstimatorSpec(
           tf.estimator.ModeKeys.PREDICT,
-          predictions=logits)
+          predictions=logits,
+          export_outputs={
+              "translate": tf.estimator.export.PredictOutput(logits)
+          })
 
     # Explicitly set the shape of the logits for XLA (TPU). This is needed
     # because the logits are passed back to the host VM CPU for metric
@@ -218,9 +221,9 @@ def get_global_step(estimator):
   return int(estimator.latest_checkpoint().split("-")[-1])
 
-def evaluate_and_log_bleu(estimator, bleu_source, bleu_ref, vocab_file_path):
+def evaluate_and_log_bleu(estimator, bleu_source, bleu_ref, vocab_file):
   """Calculate and record the BLEU score."""
-  subtokenizer = tokenizer.Subtokenizer(vocab_file_path)
+  subtokenizer = tokenizer.Subtokenizer(vocab_file)
 
   uncased_score, cased_score = translate_and_compute_bleu(
       estimator, subtokenizer, bleu_source, bleu_ref)
@@ -229,10 +232,9 @@ def evaluate_and_log_bleu(estimator, bleu_source, bleu_ref, vocab_file_path):
   tf.logging.info("Bleu score (cased):", cased_score)
   return uncased_score, cased_score
 
 def run_loop(
     estimator, schedule_manager, train_hooks=None, benchmark_logger=None,
-    bleu_source=None, bleu_ref=None, bleu_threshold=None, vocab_file_path=None):
+    bleu_source=None, bleu_ref=None, bleu_threshold=None, vocab_file=None):
   """Train and evaluate model, and optionally compute model's BLEU score.
 
   **Step vs. Epoch vs. Iteration**
@@ -264,7 +266,11 @@ def run_loop(
     bleu_source: File containing text to be translated for BLEU calculation.
     bleu_ref: File containing reference translations for BLEU calculation.
     bleu_threshold: minimum BLEU score before training is stopped.
-    vocab_file_path: Path to vocabulary file used to subtokenize bleu_source.
+    vocab_file: Path to vocab file that will be used to subtokenize bleu_source.
+
+  Raises:
+    ValueError: if both or none of single_iteration_train_steps and
+      single_iteration_train_epochs were defined.
   """
   evaluate_bleu = bleu_source is not None and bleu_ref is not None
@@ -323,7 +329,7 @@ def run_loop(
   # are compared to reference file to get the actual bleu score.
   if evaluate_bleu:
     uncased_score, cased_score = evaluate_and_log_bleu(
-        estimator, bleu_source, bleu_ref, vocab_file_path)
+        estimator, bleu_source, bleu_ref, vocab_file)
 
     # Write actual bleu scores using summary writer and benchmark logger
     global_step = get_global_step(estimator)
@@ -347,7 +353,7 @@ def run_loop(
 def define_transformer_flags():
   """Add flags and flag validators for running transformer_main."""
   # Add common flags (data_dir, model_dir, train_epochs, etc.).
-  flags_core.define_base(multi_gpu=False, num_gpu=False, export_dir=False)
+  flags_core.define_base(multi_gpu=False, num_gpu=False)
   flags_core.define_performance(
       num_parallel_calls=True,
       inter_op=False,
@@ -401,22 +407,23 @@ def define_transformer_flags():
       name="bleu_source", short_name="bls", default=None,
       help=flags_core.help_wrap(
           "Path to source file containing text translate when calculating the "
-          "official BLEU score. --bleu_source, --bleu_ref, and --vocab_file "
-          "must be set. Use the flag --stop_threshold to stop the script based "
-          "on the uncased BLEU score."))
+          "official BLEU score. Both --bleu_source and --bleu_ref must be set. "
+          "Use the flag --stop_threshold to stop the script based on the "
+          "uncased BLEU score."))
   flags.DEFINE_string(
       name="bleu_ref", short_name="blr", default=None,
       help=flags_core.help_wrap(
          "Path to source file containing text translate when calculating the "
-          "official BLEU score. --bleu_source, --bleu_ref, and --vocab_file "
-          "must be set. Use the flag --stop_threshold to stop the script based "
-          "on the uncased BLEU score."))
+          "official BLEU score. Both --bleu_source and --bleu_ref must be set. "
+          "Use the flag --stop_threshold to stop the script based on the "
+          "uncased BLEU score."))
   flags.DEFINE_string(
-      name="vocab_file", short_name="vf", default=VOCAB_FILE,
+      name="vocab_file", short_name="vf", default=None,
       help=flags_core.help_wrap(
-          "Name of vocabulary file containing subtokens for subtokenizing the "
-          "bleu_source file. This file is expected to be in the directory "
-          "defined by --data_dir."))
+          "Path to subtoken vocabulary file. If data_download.py was used to "
+          "download and encode the training data, look in the data_dir to find "
+          "the vocab file."))
+  flags.mark_flag_as_required("vocab_file")
 
   flags_core.set_defaults(data_dir="/tmp/translate_ende",
                           model_dir="/tmp/transformer_model",
@@ -431,20 +438,21 @@ def define_transformer_flags():
     return flag_dict["train_epochs"] is None or flag_dict["train_steps"] is None
 
   @flags.multi_flags_validator(
-      ["data_dir", "bleu_source", "bleu_ref", "vocab_file"],
-      message="--bleu_source, --bleu_ref, and/or --vocab_file don't exist. "
+      ["bleu_source", "bleu_ref"],
+      message="Files specified by --bleu_source and/or --bleu_ref don't exist. "
              "Please ensure that the file paths are correct.")
   def _check_bleu_files(flags_dict):
     """Validate files when bleu_source and bleu_ref are defined."""
     if flags_dict["bleu_source"] is None or flags_dict["bleu_ref"] is None:
       return True
-    # Ensure that bleu_source, bleu_ref, and vocab files exist.
-    vocab_file_path = os.path.join(
-        flags_dict["data_dir"], flags_dict["vocab_file"])
     return all([
         tf.gfile.Exists(flags_dict["bleu_source"]),
-        tf.gfile.Exists(flags_dict["bleu_ref"]),
-        tf.gfile.Exists(vocab_file_path)])
+        tf.gfile.Exists(flags_dict["bleu_ref"])])
+
+  @flags.validator("vocab_file", "File set by --vocab_file does not exist.")
+  def _check_vocab_file(vocab_file):
+    """Ensure that vocab file exists."""
+    return tf.gfile.Exists(vocab_file)
 
   flags_core.require_cloud_storage(["data_dir", "model_dir"])
@@ -552,7 +560,20 @@ def run_transformer(flags_obj):
       bleu_source=flags_obj.bleu_source,
       bleu_ref=flags_obj.bleu_ref,
       bleu_threshold=flags_obj.stop_threshold,
-      vocab_file_path=os.path.join(flags_obj.data_dir, flags_obj.vocab_file))
+      vocab_file=flags_obj.vocab_file)
+
+  if flags_obj.export_dir:
+    serving_input_fn = export.build_tensor_serving_input_receiver_fn(
+        shape=[None], dtype=tf.int64, batch_size=None)
+    # Export saved model, and save the vocab file as an extra asset. The vocab
+    # file is saved to allow consistent input encoding and output decoding.
+    # (See the "Export trained model" section in the README for an example of
+    # how to use the vocab file.)
+    # Since the model itself does not use the vocab file, this file is saved as
+    # an extra asset rather than a core asset.
+    estimator.export_savedmodel(
+        flags_obj.export_dir, serving_input_fn,
+        assets_extra={"vocab.txt": flags_obj.vocab_file})
 
 def main(_):
@@ -27,8 +27,6 @@ from absl import flags
 import tensorflow as tf
 # pylint: enable=g-bad-import-order
 
-from official.transformer.data_download import VOCAB_FILE
-from official.transformer.model import model_params
 from official.transformer.utils import tokenizer
 from official.utils.flags import core as flags_core
@@ -163,8 +161,7 @@ def main(unused_argv):
                      "flags --text or --file.")
     return
 
-  subtokenizer = tokenizer.Subtokenizer(
-      os.path.join(FLAGS.data_dir, FLAGS.vocab_file))
+  subtokenizer = tokenizer.Subtokenizer(FLAGS.vocab_file)
 
   # Set up estimator and params
   params = transformer_main.PARAMS_MAP[FLAGS.param_set]
@@ -196,17 +193,7 @@ def main(unused_argv):
 def define_translate_flags():
   """Define flags used for translation script."""
-  # Model and vocab file flags
-  flags.DEFINE_string(
-      name="data_dir", short_name="dd", default="/tmp/translate_ende",
-      help=flags_core.help_wrap(
-          "Directory for where the translate_ende_wmt32k dataset is saved."))
-  flags.DEFINE_string(
-      name="vocab_file", short_name="vf", default=VOCAB_FILE,
-      help=flags_core.help_wrap(
-          "Name of vocabulary file containing subtokens for subtokenizing the "
-          "input text or file. This file is expected to be in the directory "
-          "defined by --data_dir."))
+  # Model flags
   flags.DEFINE_string(
       name="model_dir", short_name="md", default="/tmp/transformer_model",
       help=flags_core.help_wrap(
@@ -221,6 +208,13 @@ def define_translate_flags():
           "and various other settings. The big parameter set increases the "
          "default batch size, embedding/hidden size, and filter size. For a "
          "complete list of parameters, please see model/model_params.py."))
+  flags.DEFINE_string(
+      name="vocab_file", short_name="vf", default=None,
+      help=flags_core.help_wrap(
+          "Path to subtoken vocabulary file. If data_download.py was used to "
+          "download and encode the training data, look in the data_dir to find "
+          "the vocab file."))
+  flags.mark_flag_as_required("vocab_file")
   flags.DEFINE_string(
       name="text", default=None,