Commit 7d30a017 authored by Ivan Bogatyy, committed by calberti

Release DRAGNN (#1177)

* Release DRAGNN

* Update CoNLL evaluation table & evaluator.py
parent c774cc95
.git
bazel/
Dockerfile*
tensorflow/.git
# Java baseimage, for Bazel.
FROM java:8
ENV SYNTAXNETDIR=/opt/tensorflow PATH=$PATH:/root/bin
# Install system packages. This doesn't include everything the TensorFlow
# dockerfile specifies, so if anything goes awry, maybe install more packages
# from there. Also, running apt-get clean before further commands will make the
# Docker images smaller.
RUN mkdir -p $SYNTAXNETDIR \
&& cd $SYNTAXNETDIR \
&& apt-get update \
&& apt-get install -y \
file \
git \
graphviz \
libcurl3-dev \
libfreetype6-dev \
libgraphviz-dev \
liblapack-dev \
libopenblas-dev \
libpng12-dev \
libxft-dev \
python-dev \
python-mock \
python-pip \
python2.7 \
swig \
vim \
zlib1g-dev \
&& apt-get clean \
&& (rm -f /var/cache/apt/archives/*.deb \
/var/cache/apt/archives/partial/*.deb /var/cache/apt/*.bin || true)
# Install common Python dependencies. Similar to above, remove caches
# afterwards to help keep Docker images smaller.
RUN pip install --ignore-installed pip \
&& python -m pip install numpy \
&& rm -rf /root/.cache/pip /tmp/pip*
RUN python -m pip install \
asciitree \
ipykernel \
jupyter \
matplotlib \
pandas \
protobuf \
scipy \
sklearn \
&& python -m ipykernel.kernelspec \
&& python -m pip install pygraphviz \
--install-option="--include-path=/usr/include/graphviz" \
--install-option="--library-path=/usr/lib/graphviz/" \
&& rm -rf /root/.cache/pip /tmp/pip*
# Installs the latest version of Bazel.
RUN wget --quiet https://github.com/bazelbuild/bazel/releases/download/0.4.3/bazel-0.4.3-installer-linux-x86_64.sh \
&& chmod +x bazel-0.4.3-installer-linux-x86_64.sh \
&& ./bazel-0.4.3-installer-linux-x86_64.sh \
&& rm ./bazel-0.4.3-installer-linux-x86_64.sh
COPY WORKSPACE $SYNTAXNETDIR/syntaxnet/WORKSPACE
COPY tools/bazel.rc $SYNTAXNETDIR/syntaxnet/tools/bazel.rc
COPY tensorflow $SYNTAXNETDIR/syntaxnet/tensorflow
# Compile common TensorFlow targets, which don't depend on DRAGNN / SyntaxNet
# source. This makes it more convenient to re-compile DRAGNN / SyntaxNet for
# development (though not as convenient as the docker-devel scripts).
RUN cd $SYNTAXNETDIR/syntaxnet/tensorflow \
&& tensorflow/tools/ci_build/builds/configured CPU \
&& cd $SYNTAXNETDIR/syntaxnet \
&& bazel build -c opt @org_tensorflow//tensorflow:tensorflow_py
# Build the codez.
WORKDIR $SYNTAXNETDIR/syntaxnet
COPY dragnn $SYNTAXNETDIR/syntaxnet/dragnn
COPY syntaxnet $SYNTAXNETDIR/syntaxnet/syntaxnet
COPY third_party $SYNTAXNETDIR/syntaxnet/third_party
COPY util/utf8 $SYNTAXNETDIR/syntaxnet/util/utf8
RUN bazel build -c opt //dragnn/python:all //dragnn/tools:all
# This makes the IP exposed actually "*"; we'll do host restrictions by passing
# a hostname to the `docker run` command.
COPY tensorflow/tensorflow/tools/docker/jupyter_notebook_config.py /root/.jupyter/
EXPOSE 8888
CMD [ "sh", "-c", "echo 'Bob brought the pizza to Alice.' | syntaxnet/demo.sh" ] # This does not need to be compiled, only copied.
COPY examples $SYNTAXNETDIR/syntaxnet/examples
# Todo: Move this earlier in the file (don't want to invalidate caches for now).
RUN jupyter nbextension enable --py --sys-prefix widgetsnbextension
# COMMANDS to build and run
# ===============================
# mkdir build && cp Dockerfile build/ && cd build
# docker build -t syntaxnet .
# docker run syntaxnet
CMD /bin/bash -c "bazel-bin/dragnn/tools/oss_notebook_launcher notebook --debug --notebook-dir=/opt/tensorflow/syntaxnet/examples"
# SyntaxNet: Neural Models of Syntax.

*A TensorFlow toolkit for deep learning powered natural language understanding
(NLU).*

**CoNLL**: See [here](g3doc/conll2017/README.md) for instructions for using the
SyntaxNet/DRAGNN baseline for the CoNLL2017 Shared Task.

At Google, we spend a lot of time thinking about how computer systems can read
and understand human language in order to process it in intelligent ways. We are
excited to share the fruits of our research with the broader community by
releasing SyntaxNet, an open-source neural network framework for
[TensorFlow](http://www.tensorflow.org) that provides a foundation for Natural
Language Understanding (NLU) systems. Our release includes all the code needed
to train new SyntaxNet models on your own data, as well as a suite of models
that we have trained for you, and that you can use to analyze text in over 40
languages.

This repository is largely divided into two sub-packages:

1.  **DRAGNN:
    [code](https://github.com/tensorflow/models/tree/master/syntaxnet/dragnn),
    [documentation](g3doc/DRAGNN.md)** implements Dynamic Recurrent Acyclic
    Graphical Neural Networks (DRAGNN), a framework for building multi-task,
    fully dynamically constructed computation graphs. Practically, we use
    DRAGNN to extend our prior work from [Andor et al.
    (2016)](http://arxiv.org/abs/1603.06042) with end-to-end, deep recurrent
    models and to provide a much easier to use interface to SyntaxNet.
1.  **SyntaxNet:
    [code](https://github.com/tensorflow/models/tree/master/syntaxnet/syntaxnet),
    [documentation](g3doc/syntaxnet-tutorial.md)** is a transition-based
    framework for natural language processing, with core functionality for
    feature extraction, representing annotated data, and evaluation. As of the
    DRAGNN release, it is recommended to train and deploy SyntaxNet models
    using the DRAGNN framework.

## How to use this library

There are three ways to use SyntaxNet:

*   See [here](g3doc/conll2017/README.md) for instructions for using the
    SyntaxNet/DRAGNN baseline for the CoNLL2017 Shared Task, and running the
    ParseySaurus models.
*   You can use DRAGNN to train your NLP models for other tasks and datasets.
    See "Getting started with DRAGNN" below.
*   You can continue to use the Parsey McParseface family of pre-trained
    SyntaxNet models. See "Pre-trained NLP models" below.

## Installation

### Docker installation

The simplest way to get started with DRAGNN is by loading our Docker container.
[Here](g3doc/CLOUD.md) is a tutorial for running the DRAGNN container on
[GCP](https://cloud.google.com) (just as applicable to your own computer).
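If you would rather build and run the image yourself, the `docker-devel` scripts in this repository boil down to roughly the following commands (the `dragnn-oss` image name and the port mapping are taken from those scripts; treat this as a sketch rather than the canonical workflow):

```shell
# Build the image from the repository root, then run it with the Jupyter port
# (8888, which the image exposes) forwarded to localhost.
docker build -t dragnn-oss .
docker run --rm -ti -p 127.0.0.1:8888:8888 dragnn-oss
```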
## Contents
* [Installation](#installation)
* [Getting Started](#getting-started)
* [Parsing from Standard Input](#parsing-from-standard-input)
* [Annotating a Corpus](#annotating-a-corpus)
* [Configuring the Python Scripts](#configuring-the-python-scripts)
* [Next Steps](#next-steps)
* [Detailed Tutorial: Building an NLP Pipeline with SyntaxNet](#detailed-tutorial-building-an-nlp-pipeline-with-syntaxnet)
* [Obtaining Data](#obtaining-data)
* [Part-of-Speech Tagging](#part-of-speech-tagging)
* [Training the SyntaxNet POS Tagger](#training-the-syntaxnet-pos-tagger)
* [Preprocessing with the Tagger](#preprocessing-with-the-tagger)
* [Dependency Parsing: Transition-Based Parsing](#dependency-parsing-transition-based-parsing)
* [Training a Parser Step 1: Local Pretraining](#training-a-parser-step-1-local-pretraining)
* [Training a Parser Step 2: Global Training](#training-a-parser-step-2-global-training)
* [Contact](#contact)
* [Credits](#credits)
### Manual installation

Running and training SyntaxNet/DRAGNN models requires building this package from
source. You'll need to install:

*   python 2.7:
    *   Python 3 support is not available yet
*   bazel:
    *   Follow the instructions [here](http://bazel.build/docs/install.html)
    *   Alternatively, download the bazel <.deb> from
        [https://github.com/bazelbuild/bazel/releases](https://github.com/bazelbuild/bazel/releases)
        for your system configuration.
    *   Install it using the command: sudo dpkg -i <.deb file>
    *   Check the bazel version by typing: bazel version
*   swig:

...@@ -99,6 +80,11 @@ source. You'll need to install:

    *   `pip install asciitree`
*   numpy, package for scientific computing:
    *   `pip install numpy`
*   pygraphviz to visualize traces and parse trees:
    *   `apt-get install -y graphviz libgraphviz-dev`
    *   `pip install pygraphviz
        --install-option="--include-path=/usr/include/graphviz"
        --install-option="--library-path=/usr/lib/graphviz/"`
Once you have completed the above steps, you can build and test SyntaxNet with
the following commands:

...@@ -108,17 +94,14 @@ following commands:

  cd models/syntaxnet/tensorflow
  ./configure
  cd ..
  bazel test ...
  # On Mac, run the following:
  bazel test --linkopt=-headerpad_max_install_names \
    dragnn/... syntaxnet/... util/utf8/...
```

Bazel should complete reporting all tests passed.

You can also compile SyntaxNet in a [Docker](https://www.docker.com/what-docker)
container using this [Dockerfile](Dockerfile).

To build SyntaxNet with GPU support please refer to the instructions in
[issues/248](https://github.com/tensorflow/models/issues/248).

...@@ -127,12 +110,64 @@ memory allocated for your Docker VM.

## Getting Started
We have a few guides on this README, as well as more extensive
[documentation](g3doc/).
### Learning the DRAGNN framework
![DRAGNN](g3doc/unrolled-dragnn.png)
An easy and visual way to get started with DRAGNN is to run [our Jupyter
Notebook](examples/dragnn/basic_parser_tutorial.ipynb). Our tutorial
[here](g3doc/CLOUD.md) explains how to start it up from the Docker container.
### Using the Pre-trained NLP models
We are happy to release *Parsey McParseface*, an English parser that we have
trained for you, and that you can use to analyze English text, along with
[trained models for 40 languages](g3doc/universal.md) and support for text
segmentation and morphological analysis.
Once you have successfully built SyntaxNet, you can start parsing text right
away with Parsey McParseface, located under `syntaxnet/models`. The easiest
thing is to use or modify the included script `syntaxnet/demo.sh`, which shows a
basic setup to parse English taking plain text as input.

You can also skip right away to the [detailed SyntaxNet
tutorial](g3doc/syntaxnet-tutorial.md).
How accurate is Parsey McParseface? For the initial release, we tried to balance
a model that runs fast enough to be useful on a single machine (e.g. ~600
words/second on a modern desktop) and that is also the most accurate parser
available. Here's how Parsey McParseface compares to the academic literature on
several different English domains: (all numbers are % correct head assignments
in the tree, or unlabelled attachment score)
Model | News | Web | Questions
--------------------------------------------------------------------------------------------------------------- | :---: | :---: | :-------:
[Martins et al. (2013)](http://www.cs.cmu.edu/~ark/TurboParser/) | 93.10 | 88.23 | 94.21
[Zhang and McDonald (2014)](http://research.google.com/pubs/archive/38148.pdf) | 93.32 | 88.65 | 93.37
[Weiss et al. (2015)](http://static.googleusercontent.com/media/research.google.com/en//pubs/archive/43800.pdf) | 93.91 | 89.29 | 94.17
[Andor et al. (2016)](http://arxiv.org/abs/1603.06042)* | 94.44 | 90.17 | 95.40
Parsey McParseface | 94.15 | 89.08 | 94.77
We see that Parsey McParseface is state-of-the-art; more importantly, with
SyntaxNet you can train larger networks with more hidden units and bigger beam
sizes if you want to push the accuracy even further: [Andor et al.
(2016)](http://arxiv.org/abs/1603.06042)* is simply a SyntaxNet model with a
larger beam and network. For further information on the datasets, see that paper
under the section "Treebank Union".
Parsey McParseface is also state-of-the-art for part-of-speech (POS) tagging
(numbers below are per-token accuracy):
Model | News | Web | Questions
-------------------------------------------------------------------------- | :---: | :---: | :-------:
[Ling et al. (2015)](http://www.cs.cmu.edu/~lingwang/papers/emnlp2015.pdf) | 97.44 | 94.03 | 96.18
[Andor et al. (2016)](http://arxiv.org/abs/1603.06042)* | 97.77 | 94.80 | 96.86
Parsey McParseface | 97.52 | 94.24 | 96.45
#### Parsing from Standard Input
Simply pass one sentence per line of text into the script at
`syntaxnet/demo.sh`. The script will break the text into words, run the POS

...@@ -160,7 +195,7 @@ visualized in our tutorial graphs. In this example, we see that the verb

If you want to feed in tokenized, CONLL-formatted text, you can run `demo.sh
--conll`.
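For example, after building, a quick plain-text sanity check from the top-level `syntaxnet` directory looks like this (the sentence is arbitrary):

```shell
echo 'Bob brought the pizza to Alice.' | syntaxnet/demo.sh
```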
#### Annotating a Corpus

To change the pipeline to read and write to specific files (as opposed to piping
through stdin and stdout), we have to modify the `demo.sh` to point to the files

...@@ -200,7 +235,7 @@ input {

Then we can use `--input=wsj-data --output=wsj-data-tagged` on the command line
to specify reading and writing to these files.
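For reference, a minimal sketch of such an entry in `context.pbtxt` is shown below. The `wsj-data` name matches the `--input` flag above; the field names follow the task-context format used elsewhere in this package, and the file pattern is a placeholder you would point at your own corpus.

```
input {
  name: 'wsj-data'
  record_format: 'conll-sentence'
  Part {
    file_pattern: './wsj.conll'
  }
}
```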
#### Configuring the Python Scripts

As mentioned above, the python scripts are configured in two ways:

...@@ -234,386 +269,13 @@ There are many ways to extend this framework, e.g. adding new features, changing

the model structure, training on other languages, etc. We suggest reading the
detailed tutorial below to get a handle on the rest of the framework.
## Detailed Tutorial: Building an NLP Pipeline with SyntaxNet
In this tutorial, we'll go over how to train new models, and explain in a bit
more technical detail the NLP side of the models. Our goal here is to explain
the NLP pipeline produced by this package.
### Obtaining Data
The included English parser, Parsey McParseface, was trained on the standard
corpora of the [Penn Treebank](https://catalog.ldc.upenn.edu/LDC99T42) and
[OntoNotes](https://catalog.ldc.upenn.edu/LDC2013T19), as well as the [English
Web Treebank](https://catalog.ldc.upenn.edu/LDC2012T13), but these are
unfortunately not freely available.
However, the [Universal Dependencies](http://universaldependencies.org/) project
provides freely available treebank data in a number of languages. SyntaxNet can
be trained and evaluated on any of these corpora.
### Part-of-Speech Tagging
Consider the following sentence, which exhibits several ambiguities that affect
its interpretation:
> I saw the man with glasses.
This sentence is composed of words: strings of characters that are segmented
into groups (e.g. "I", "saw", etc.) Each word in the sentence has a *grammatical
function* that can be useful for understanding the meaning of language. For
example, "saw" in this example is a past tense of the verb "to see". But any
given word might have different meanings in different contexts: "saw" could just
as well be a noun (e.g., a saw used for cutting) or a present tense verb (using
a saw to cut something).
A logical first step in understanding language is figuring out these roles for
each word in the sentence. This process is called *Part-of-Speech (POS)
Tagging*. The roles are called POS tags. Although a given word might have
multiple possible tags depending on the context, given any one interpretation of
a sentence each word will generally only have one tag.
One interesting challenge of POS tagging is that the problem of defining a
vocabulary of POS tags for a given language is quite involved. While the concept
of nouns and verbs is pretty common, it has been traditionally difficult to
agree on a standard set of roles across all languages. The [Universal
Dependencies](http://www.universaldependencies.org) project aims to solve this
problem.
### Training the SyntaxNet POS Tagger
In general, determining the correct POS tag requires understanding the entire
sentence and the context in which it is uttered. In practice, we can do very
well just by considering a small window of words around the word of interest.
For example, words that follow the word ‘the’ tend to be adjectives or nouns,
rather than verbs.
To predict POS tags, we use a simple setup. We process the sentences
left-to-right. For any given word, we extract features of that word and a window
around it, and use these as inputs to a feed-forward neural network classifier,
which predicts a probability distribution over POS tags. Because we make
decisions in left-to-right order, we also use prior decisions as features in
subsequent ones (e.g. "the previous predicted tag was a noun.").
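As a rough illustration of that left-to-right loop (this is not the actual implementation; the hypothetical `classify` callable stands in for the feed-forward network):

```python
# Illustration only: tag words left to right, feeding the classifier a window
# of surrounding words plus the tags already predicted to the left.
def tag_sentence(words, classify, window=3):
    tags = []
    for i in range(len(words)):
        features = {
            "window": words[max(0, i - window):i + window + 1],
            "previous_tags": tags[-window:],
        }
        tags.append(classify(features))  # pick the highest-probability tag
    return tags
```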
All the models in this package use a flexible markup language to define
features. For example, the features in the POS tagger are found in the
`brain_pos_features` parameter in the `TaskSpec`, and look like this (modulo
spacing):
```
stack(3).word stack(2).word stack(1).word stack.word input.word input(1).word input(2).word input(3).word;
input.digit input.hyphen;
stack.suffix(length=2) input.suffix(length=2) input(1).suffix(length=2);
stack.prefix(length=2) input.prefix(length=2) input(1).prefix(length=2)
```
Note that `stack` here means "words we have already tagged." Thus, this feature
spec uses three types of features: words, suffixes, and prefixes. The features
are grouped into blocks that share an embedding matrix, concatenated together,
and fed into a chain of hidden layers. This structure is based upon the model
proposed by [Chen and Manning
(2014)](http://cs.stanford.edu/people/danqi/papers/emnlp2014.pdf).
We show this layout in the schematic below: the state of the system (a stack and
a buffer, visualized below for both the POS and the dependency parsing task) is
used to extract sparse features, which are fed into the network in groups. We
show only a small subset of the features to simplify the presentation in the
schematic:
![Schematic](ff_nn_schematic.png "Feed-forward Network Structure")
In the configuration above, each block gets its own embedding matrix and the
blocks in the configuration above are delineated with a semi-colon. The
dimensions of each block are controlled in the `brain_pos_embedding_dims`
parameter. **Important note:** unlike many simple NLP models, this is *not* a
bag of words model. Remember that although certain features share embedding
matrices, the above features will be concatenated, so the interpretation of
`input.word` will be quite different from `input(1).word`. This also means that
adding features increases the dimension of the `concat` layer of the model as
well as the number of parameters for the first hidden layer.
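To make the block structure concrete, here is a tiny numpy-only sketch of "look up each feature in its block's embedding matrix, concatenate everything, feed it forward". The vocabulary sizes, dimensions, and feature ids are invented for the example; the real model is, of course, built as a TensorFlow graph by this package.

```python
import numpy as np

rng = np.random.RandomState(0)

# One embedding matrix per feature block (words, suffixes, prefixes).
vocab_sizes = {"word": 1000, "suffix": 200, "prefix": 200}
embed_dims = {"word": 64, "suffix": 16, "prefix": 16}
embeddings = {b: 0.01 * rng.randn(vocab_sizes[b], embed_dims[b])
              for b in vocab_sizes}

def features_to_input(feature_ids):
    """Maps {block name: list of feature ids} to one concatenated vector.

    Features in the same block share an embedding matrix, but every feature
    keeps its own slot in the concatenation, so input.word and input(1).word
    are not interchangeable (this is not a bag-of-words model).
    """
    parts = [embeddings[block][ids].reshape(-1)
             for block, ids in feature_ids.items()]
    return np.concatenate(parts)

# 8 word features, 3 suffix features, 3 prefix features, as in the spec above.
x = features_to_input({"word": [5, 7, 2, 0, 1, 9, 4, 3],
                       "suffix": [11, 3, 8],
                       "prefix": [6, 2, 5]})
hidden = np.maximum(0.0, x @ (0.01 * rng.randn(x.size, 128)))  # first hidden layer
```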
To train the model, first edit `syntaxnet/context.pbtxt` so that the inputs
`training-corpus`, `tuning-corpus`, and `dev-corpus` point to the location of
your training data. You can then train a part-of-speech tagger with:
```shell
bazel-bin/syntaxnet/parser_trainer \
--task_context=syntaxnet/context.pbtxt \
--arg_prefix=brain_pos \ # read from POS configuration
--compute_lexicon \ # required for first stage of pipeline
--graph_builder=greedy \ # no beam search
--training_corpus=training-corpus \ # names of training/tuning set
--tuning_corpus=tuning-corpus \
--output_path=models \ # where to save new resources
--batch_size=32 \ # Hyper-parameters
--decay_steps=3600 \
--hidden_layer_sizes=128 \
--learning_rate=0.08 \
--momentum=0.9 \
--seed=0 \
--params=128-0.08-3600-0.9-0 # name for these parameters
```
This will read in the data, construct a lexicon, build a tensorflow graph for
the model with the specific hyperparameters, and train the model. Every so often
the model will be evaluated on the tuning set, and only the checkpoint with the
highest accuracy on this set will be saved. **Note that you should never use a
corpus you intend to test your model on as your tuning set, as you will inflate
your test set results.**
For best results, you should repeat this command with at least 3 different
seeds, and possibly with a few different values for `--learning_rate` and
`--decay_steps`. Good values for `--learning_rate` are usually close to 0.1, and
you usually want `--decay_steps` to correspond to about one tenth of your
corpus. The `--params` flag is only a human readable identifier for the model
being trained, used to construct the full output path, so that you don't need to
worry about clobbering old models by accident.
The `--arg_prefix` flag controls which parameters should be read from the task
context file `context.pbtxt`. In this case `arg_prefix` is set to `brain_pos`,
so the parameters being used in this training run are
`brain_pos_transition_system`, `brain_pos_embedding_dims`, `brain_pos_features`
and, `brain_pos_embedding_names`. To train the dependency parser later
`arg_prefix` will be set to `brain_parser`.
### Preprocessing with the Tagger
Now that we have a trained POS tagging model, we want to use the output of this
model as features in the parser. Thus the next step is to run the trained model
over our training, tuning, and dev (evaluation) sets. We can use the
`parser_eval.py` script for this.
For example, the model `128-0.08-3600-0.9-0` trained above can be run over the
training, tuning, and dev sets with the following command:
```shell
PARAMS=128-0.08-3600-0.9-0
for SET in training tuning dev; do
bazel-bin/syntaxnet/parser_eval \
--task_context=models/brain_pos/greedy/$PARAMS/context \
--hidden_layer_sizes=128 \
--input=$SET-corpus \
--output=tagged-$SET-corpus \
--arg_prefix=brain_pos \
--graph_builder=greedy \
--model_path=models/brain_pos/greedy/$PARAMS/model
done
```
**Important note:** This command only works because we have created entries for
you in `context.pbtxt` that correspond to `tagged-training-corpus`,
`tagged-dev-corpus`, and `tagged-tuning-corpus`. From these default settings,
the above will write tagged versions of the training, tuning, and dev set to the
directory `models/brain_pos/greedy/$PARAMS/`. This location is chosen because
the `input` entries do not have `file_pattern` set: instead, they have `creator:
brain_pos/greedy`, which means that `parser_trainer.py` will construct *new*
files when called with `--arg_prefix=brain_pos --graph_builder=greedy` using the
`--model_path` flag to determine the location.
For convenience, `parser_eval.py` also logs POS tagging accuracy after the
output tagged datasets have been written.
### Dependency Parsing: Transition-Based Parsing
Now that we have a prediction for the grammatical role of the words, we want to
understand how the words in the sentence relate to each other. This parser is
built around the *head-modifier* construction: for each word, we choose a
*syntactic head* that it modifies according to some grammatical role.
An example for the above sentence is as follows:
![Figure](sawman.png)
Below each word in the sentence we see both a fine-grained part-of-speech
(*PRP*, *VBD*, *DT*, *NN* etc.), and a coarse-grained part-of-speech (*PRON*,
*VERB*, *DET*, *NOUN*, etc.). Coarse-grained POS tags encode basic grammatical
categories, while the fine-grained POS tags make further distinctions: for
example *NN* is a singular noun (as opposed, for example, to *NNS*, which is a
plural noun), and *VBD* is a past-tense verb. For more discussion see [Petrov et
al. (2012)](http://www.lrec-conf.org/proceedings/lrec2012/pdf/274_Paper.pdf).
Crucially, we also see directed arcs signifying grammatical relationships
between different words in the sentence. For example *I* is the subject of
*saw*, as signified by the directed arc labeled *nsubj* between these words;
*man* is the direct object (dobj) of *saw*; the preposition *with* modifies
*man* with a prep relation, signifying modification by a prepositional phrase;
and so on. In addition the verb *saw* is identified as the *root* of the entire
sentence.
Whenever we have a directed arc between two words, we refer to the word at the
start of the arc as the *head*, and the word at the end of the arc as the
*modifier*. For example we have one arc where the head is *saw* and the modifier
is *I*, another where the head is *saw* and the modifier is *man*, and so on.
The grammatical relationships encoded in dependency structures are directly
related to the underlying meaning of the sentence in question. They allow us to
easily recover the answers to various questions, for example *whom did I see?*,
*who saw the man with glasses?*, and so on.
SyntaxNet is a **transition-based** dependency parser [Nivre
(2007)](http://www.mitpressjournals.org/doi/pdfplus/10.1162/coli.07-056-R1-07-027) that
constructs a parse incrementally. Like the tagger, it processes words
left-to-right. The words all start as unprocessed input, called the *buffer*. As
words are encountered they are put onto a *stack*. At each step, the parser can
do one of three things:
1. **SHIFT:** Push another word onto the top of the stack, i.e. shifting one
token from the buffer to the stack.
1. **LEFT_ARC:** Pop the top two words from the stack. Attach the second to the
first, creating an arc pointing to the **left**. Push the **first** word
back on the stack.
1. **RIGHT_ARC:** Pop the top two words from the stack. Attach the second to
the first, creating an arc pointing to the **right**. Push the **second** word
back on the stack.
At each step, we call the combination of the stack and the buffer the
*configuration* of the parser. For the left and right actions, we also assign a
dependency relation label to that arc. This process is visualized in the
following animation for a short sentence:
![Animation](looping-parser.gif "Parsing in Action")
Note that this parser is following a sequence of actions, called a
**derivation**, to produce a "gold" tree labeled by a linguist. We can use this
sequence of decisions to learn a classifier that takes a configuration and
predicts the next action to take.
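The transition loop itself is simple enough to sketch in a few lines. This is illustration only (the real parser is implemented in C++ and masks actions that are not allowed in the current configuration); `predict_action` stands in for the trained classifier and is an assumption, not part of this release.

```python
SHIFT, LEFT_ARC, RIGHT_ARC = "SHIFT", "LEFT_ARC", "RIGHT_ARC"

def parse(num_tokens, predict_action):
    buffer = list(range(num_tokens))  # unprocessed token indices
    stack = []                        # partially processed token indices
    arcs = []                         # (head, modifier) pairs
    while buffer or len(stack) > 1:
        action = predict_action(stack, buffer)
        if action == SHIFT and buffer:
            stack.append(buffer.pop(0))
        elif action == LEFT_ARC and len(stack) >= 2:
            head, modifier = stack[-1], stack[-2]  # the word pushed back is the head
            arcs.append((head, modifier))
            stack[-2:] = [head]
        elif action == RIGHT_ARC and len(stack) >= 2:
            head, modifier = stack[-2], stack[-1]  # the word pushed back is the head
            arcs.append((head, modifier))
            stack[-2:] = [head]
        else:
            break  # action not permitted in this configuration
    return arcs
```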
### Training a Parser Step 1: Local Pretraining
As described in our [paper](http://arxiv.org/abs/1603.06042), the first
step in training the model is to *pre-train* using *local* decisions. In this
phase, we use the gold dependency to guide the parser, and train a softmax layer
to predict the correct action given these gold dependencies. This can be
performed very efficiently, since the parser's decisions are all independent in
this setting.
Once the tagged datasets are available, a locally normalized dependency parsing
model can be trained with the following command:
```shell
bazel-bin/syntaxnet/parser_trainer \
--arg_prefix=brain_parser \
--batch_size=32 \
--projectivize_training_set \
--decay_steps=4400 \
--graph_builder=greedy \
--hidden_layer_sizes=200,200 \
--learning_rate=0.08 \
--momentum=0.85 \
--output_path=models \
--task_context=models/brain_pos/greedy/$PARAMS/context \
--seed=4 \
--training_corpus=tagged-training-corpus \
--tuning_corpus=tagged-tuning-corpus \
--params=200x200-0.08-4400-0.85-4
```
Note that we point the trainer to the context corresponding to the POS tagger
that we picked previously. This allows the parser to reuse the lexicons and the
tagged datasets that were created in the previous steps. Processing data can be
done similarly to how tagging was done above. For example if in this case we
picked parameters `200x200-0.08-4400-0.85-4`, the training, tuning and dev sets
can be parsed with the following command:
```shell
PARAMS=200x200-0.08-4400-0.85-4
for SET in training tuning dev; do
bazel-bin/syntaxnet/parser_eval \
--task_context=models/brain_parser/greedy/$PARAMS/context \
--hidden_layer_sizes=200,200 \
--input=tagged-$SET-corpus \
--output=parsed-$SET-corpus \
--arg_prefix=brain_parser \
--graph_builder=greedy \
--model_path=models/brain_parser/greedy/$PARAMS/model
done
```
### Training a Parser Step 2: Global Training
As we describe in the paper, there are several problems with the locally
normalized models we just trained. The most important is the *label-bias*
problem: the model doesn't learn what a good parse looks like, only what action
to take given a history of gold decisions. This is because the scores are
normalized *locally* using a softmax for each decision.
In the paper, we show how we can achieve much better results using a *globally*
normalized model: in this model, the softmax scores are summed in log space, and
the scores are not normalized until we reach a final decision. When the parser
stops, the scores of each hypothesis are normalized against a small set of
possible parses (in the case of this model, a beam size of 8). When training, we
force the parser to stop during parsing when the gold derivation falls off the
beam (a strategy known as early-updates).
We give a simplified view of how this training works for a [garden path
sentence](https://en.wikipedia.org/wiki/Garden_path_sentence), where it is
important to maintain multiple hypotheses. A single mistake early on in parsing
leads to a completely incorrect parse; after training, the model learns to
prefer the second (correct) parse.
![Beam search training](beam_search_training.png)
Parsey McParseface correctly parses this sentence. Even though the correct parse
is initially ranked 4th out of multiple hypotheses, when the end of the garden
path is reached, Parsey McParseface can recover due to the beam; using a larger
beam will get a more accurate model, but it will be slower (we used beam 32 for
the models in the paper).
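To pin down the idea, here is a compact sketch of that training-time beam loop; `score_actions(state)` (returning `(action, log score)` pairs) and `advance(state, action)` are hypothetical helpers, and the real implementation is the structured graph builder used below.

```python
import heapq

def beam_explore(initial_state, gold_actions, score_actions, advance, beam_size=8):
    """Expand a beam step by step; stop early if the gold derivation falls off."""
    # Each hypothesis: (sum of log scores, state, still on the gold path?).
    beam = [(0.0, initial_state, True)]
    for step, gold_action in enumerate(gold_actions):
        candidates = []
        for log_score, state, on_gold in beam:
            for action, action_log_score in score_actions(state):
                candidates.append((log_score + action_log_score,
                                   advance(state, action),
                                   on_gold and action == gold_action))
        # Scores are summed in log space and are *not* normalized per step.
        beam = heapq.nlargest(beam_size, candidates, key=lambda c: c[0])
        if not any(on_gold for _, _, on_gold in beam):
            return beam, step  # early update: compute the loss against this beam
    return beam, len(gold_actions)
```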
Once you have the pre-trained locally normalized model, a globally normalized
parsing model can now be trained with the following command:
```shell
bazel-bin/syntaxnet/parser_trainer \
--arg_prefix=brain_parser \
--batch_size=8 \
--decay_steps=100 \
--graph_builder=structured \
--hidden_layer_sizes=200,200 \
--learning_rate=0.02 \
--momentum=0.9 \
--output_path=models \
--task_context=models/brain_parser/greedy/$PARAMS/context \
--seed=0 \
--training_corpus=projectivized-training-corpus \
--tuning_corpus=tagged-tuning-corpus \
--params=200x200-0.02-100-0.9-0 \
--pretrained_params=models/brain_parser/greedy/$PARAMS/model \
--pretrained_params_names=\
embedding_matrix_0,embedding_matrix_1,embedding_matrix_2,\
bias_0,weights_0,bias_1,weights_1
```
Training a beam model with the structured builder will take a lot longer than
the greedy training runs above, perhaps 3 or 4 times longer. Note once again
that multiple restarts of training will yield the most reliable results.
Evaluation can again be done with `parser_eval.py`. In this case we use
parameters `200x200-0.02-100-0.9-0` to evaluate on the training, tuning and dev
sets with the following command:
```shell
PARAMS=200x200-0.02-100-0.9-0
for SET in training tuning dev; do
bazel-bin/syntaxnet/parser_eval \
--task_context=models/brain_parser/structured/$PARAMS/context \
--hidden_layer_sizes=200,200 \
--input=tagged-$SET-corpus \
--output=beam-parsed-$SET-corpus \
--arg_prefix=brain_parser \
--graph_builder=structured \
--model_path=models/brain_parser/structured/$PARAMS/model
done
```
Hooray! You now have your very own cousin of Parsey McParseface, ready to go out
and parse text in the wild.
## Contact

To ask questions or report issues please post on Stack Overflow with the tag
[syntaxnet](http://stackoverflow.com/questions/tagged/syntaxnet) or open an
issue on the tensorflow/models [issues
tracker](https://github.com/tensorflow/models/issues). Please assign SyntaxNet
issues to @calberti or @andorardo.
## Credits

...@@ -633,6 +295,7 @@ Original authors of the code in this package include (in alphabetical order):

* Keith Hall
* Kuzman Ganchev
* Livio Baldini Soares
* Mark Omernick
* Michael Collins
* Michael Ringgaard
* Ryan McDonald

...@@ -640,3 +303,4 @@ Original authors of the code in this package include (in alphabetical order):

* Stefan Istrate
* Terry Koo
* Tim Credo
* Zora Tung

...@@ -3,10 +3,23 @@ local_repository(

path = "tensorflow",
)
# We need to pull in @io_bazel_rules_closure for TensorFlow. Bazel design
# documentation states that this verbosity is intentional, to prevent
# TensorFlow/SyntaxNet from depending on different versions of
# @io_bazel_rules_closure.
http_archive(
name = "io_bazel_rules_closure",
sha256 = "60fc6977908f999b23ca65698c2bb70213403824a84f7904310b6000d78be9ce",
strip_prefix = "rules_closure-5ca1dab6df9ad02050f7ba4e816407f88690cf7d",
urls = [
"http://bazel-mirror.storage.googleapis.com/github.com/bazelbuild/rules_closure/archive/5ca1dab6df9ad02050f7ba4e816407f88690cf7d.tar.gz", # 2017-02-03
"https://github.com/bazelbuild/rules_closure/archive/5ca1dab6df9ad02050f7ba4e816407f88690cf7d.tar.gz",
],
)
load("@org_tensorflow//tensorflow:workspace.bzl", "tf_workspace") load("@org_tensorflow//tensorflow:workspace.bzl", "tf_workspace")
tf_workspace(path_prefix="", tf_repo_name="org_tensorflow") tf_workspace(path_prefix="", tf_repo_name="org_tensorflow")
# Test that Bazel is up-to-date. # Test that Bazel is up-to-date.
load("@org_tensorflow//tensorflow:workspace.bzl", "check_version") load("@org_tensorflow//tensorflow:workspace.bzl", "check_version")
check_version("0.4.3") check_version("0.4.2")
#!/bin/bash
#
# This file puts you in a Docker sub-shell where you can build SyntaxNet
# targets. It is intended for development, as the Dockerfile (build file) does
# not actually build any of SyntaxNet, but instead mounts it in a volume.
script_path="$(readlink -f "$0")"
root_path="$(dirname "$(dirname "${script_path}")")"
set -e
if [[ -z "$(docker images -q dragnn-oss)" ]]; then
docker build -t dragnn-oss .
else
echo "NOTE: dragnn-oss image already exists, not re-building." >&2
echo "Please run \`docker build -t dragnn-oss .\` if you need." >&2
fi
echo -e "\n\nRun bazel commands like \`bazel test syntaxnet/...\`"
# NOTE: Unfortunately, we need to mount /tensorflow over /syntaxnet/tensorflow
# (which happens via devel_entrypoint.sh). This requires privileged mode.
syntaxnet_base="/opt/tensorflow/syntaxnet"
docker run --rm -ti \
-v "${root_path}"/syntaxnet:"${syntaxnet_base}"/syntaxnet \
-v "${root_path}"/dragnn:"${syntaxnet_base}"/dragnn \
-p 127.0.0.1:8888:8888 \
dragnn-oss "$@"
#!/bin/bash
#
# Convenience script to build wheel files in Docker, and copy them out of the
# container.
#
# Usage: docker-devel/build_wheels.sh (takes no arguments; run it from the base
# directory).
set -e
docker build -t dragnn-oss .
# Start building the wheels.
script="bazel run //dragnn/tools:build_pip_package \
-- --output-dir=/opt/tensorflow/syntaxnet; \
bazel run //dragnn/tools:build_pip_package \
-- --output-dir=/opt/tensorflow/syntaxnet --include-tensorflow"
container_id="$(docker run -d dragnn-oss /bin/bash -c "${script}")"
echo "Waiting for container ${container_id} to finish building the wheel ..."
if [[ "$(docker wait "${container_id}")" != 0 ]]; then
echo "Container failed! Please run \`docker logs <id>\` to see errors." >&2
exit 1
fi
# The build_pip_package.py script prints lines like "Wrote x.whl". The wheel
# names are prefixed by architecture and such, so don't guess them.
wheels=(
$(docker logs "${container_id}" 2>/dev/null | grep Wrote | awk '{print $2;}'))
for wheel in "${wheels[@]}"; do
output=./"$(basename "${wheel}")"
docker cp "${container_id}:${wheel}" "${output}"
echo "Wrote ${output} ($(du -h "${output}" | awk '{print $1;}'))"
done
echo "Removing ${container_id} ..."
docker rm "${container_id}" >/dev/null
package_group(
name = "dragnn_visibility",
packages = [
],
)
package(default_visibility = ["//visibility:public"])
cc_library(
name = "syntaxnet_component",
srcs = ["syntaxnet_component.cc"],
hdrs = ["syntaxnet_component.h"],
deps = [
":syntaxnet_link_feature_extractor",
":syntaxnet_transition_state",
"//dragnn/components/util:bulk_feature_extractor",
"//dragnn/core:beam",
"//dragnn/core:component_registry",
"//dragnn/core:input_batch_cache",
"//dragnn/core/interfaces:component",
"//dragnn/core/interfaces:transition_state",
"//dragnn/io:sentence_input_batch",
"//dragnn/io:syntaxnet_sentence",
"//dragnn/protos:data_proto",
"//dragnn/protos:spec_proto",
"//dragnn/protos:trace_proto",
"//syntaxnet:base",
"//syntaxnet:parser_transitions",
"//syntaxnet:registry",
"//syntaxnet:sparse_proto",
"//syntaxnet:task_context",
"//syntaxnet:task_spec_proto",
"//syntaxnet:utils",
"@org_tensorflow//tensorflow/core:lib", # For tf/core/platform/logging.h
],
alwayslink = 1,
)
cc_library(
name = "syntaxnet_link_feature_extractor",
srcs = ["syntaxnet_link_feature_extractor.cc"],
hdrs = ["syntaxnet_link_feature_extractor.h"],
deps = [
"//dragnn/protos:spec_proto",
"//syntaxnet:embedding_feature_extractor",
"//syntaxnet:parser_transitions",
"//syntaxnet:task_context",
"@org_tensorflow//tensorflow/core:lib", # For tf/core/platform/logging.h
],
)
cc_library(
name = "syntaxnet_transition_state",
srcs = ["syntaxnet_transition_state.cc"],
hdrs = ["syntaxnet_transition_state.h"],
deps = [
"//dragnn/core/interfaces:cloneable_transition_state",
"//dragnn/core/interfaces:transition_state",
"//dragnn/io:syntaxnet_sentence",
"//dragnn/protos:trace_proto",
"//syntaxnet:base",
"//syntaxnet:parser_transitions",
"@org_tensorflow//tensorflow/core:lib", # For tf/core/platform/logging.h
],
)
# Test data.
filegroup(
name = "testdata",
data = glob(["testdata/**"]),
)
# Tests.
cc_test(
name = "syntaxnet_component_test",
srcs = ["syntaxnet_component_test.cc"],
data = [":testdata"],
deps = [
":syntaxnet_component",
"//dragnn/core:input_batch_cache",
"//dragnn/core/test:generic",
"//dragnn/core/test:mock_transition_state",
"//dragnn/io:sentence_input_batch",
"//syntaxnet:sentence_proto",
"@org_tensorflow//tensorflow/core:lib",
"@org_tensorflow//tensorflow/core:test",
],
)
cc_test(
name = "syntaxnet_link_feature_extractor_test",
srcs = ["syntaxnet_link_feature_extractor_test.cc"],
deps = [
":syntaxnet_link_feature_extractor",
"//dragnn/core/test:generic",
"//dragnn/protos:spec_proto",
"//syntaxnet:task_context",
"//syntaxnet:test_main",
"@org_tensorflow//tensorflow/core:test",
"@org_tensorflow//tensorflow/core:testlib",
],
)
cc_test(
name = "syntaxnet_transition_state_test",
srcs = ["syntaxnet_transition_state_test.cc"],
data = [":testdata"],
deps = [
":syntaxnet_component",
":syntaxnet_transition_state",
"//dragnn/core:input_batch_cache",
"//dragnn/core/test:generic",
"//dragnn/core/test:mock_transition_state",
"//dragnn/io:sentence_input_batch",
"//dragnn/protos:spec_proto",
"//syntaxnet:sentence_proto",
"//syntaxnet:test_main",
"@org_tensorflow//tensorflow/core:lib",
"@org_tensorflow//tensorflow/core:test",
"@org_tensorflow//tensorflow/core:testlib",
],
)
#include "dragnn/components/syntaxnet/syntaxnet_component.h"
#include <vector>
#include "dragnn/components/util/bulk_feature_extractor.h"
#include "dragnn/core/component_registry.h"
#include "dragnn/core/input_batch_cache.h"
#include "dragnn/core/interfaces/component.h"
#include "dragnn/core/interfaces/transition_state.h"
#include "dragnn/io/sentence_input_batch.h"
#include "dragnn/io/syntaxnet_sentence.h"
#include "syntaxnet/parser_state.h"
#include "syntaxnet/sparse.pb.h"
#include "syntaxnet/task_spec.pb.h"
#include "syntaxnet/utils.h"
#include "tensorflow/core/platform/logging.h"
namespace syntaxnet {
namespace dragnn {
using tensorflow::strings::StrCat;
namespace {
// Returns a new step in a trace based on a ComponentSpec.
ComponentStepTrace GetNewStepTrace(const ComponentSpec &spec,
const TransitionState &state) {
ComponentStepTrace step;
for (auto &linked_spec : spec.linked_feature()) {
auto &channel_trace = *step.add_linked_feature_trace();
channel_trace.set_name(linked_spec.name());
channel_trace.set_source_component(linked_spec.source_component());
channel_trace.set_source_translator(linked_spec.source_translator());
channel_trace.set_source_layer(linked_spec.source_layer());
}
for (auto &fixed_spec : spec.fixed_feature()) {
step.add_fixed_feature_trace()->set_name(fixed_spec.name());
}
step.set_html_representation(state.HTMLRepresentation());
return step;
}
// Returns the last step in the trace.
ComponentStepTrace *GetLastStepInTrace(ComponentTrace *trace) {
CHECK_GT(trace->step_trace_size(), 0) << "Trace has no steps added yet";
return trace->mutable_step_trace(trace->step_trace_size() - 1);
}
} // anonymous namespace
SyntaxNetComponent::SyntaxNetComponent()
: feature_extractor_("brain_parser"),
rewrite_root_labels_(false),
max_beam_size_(1),
input_data_(nullptr) {}
void SyntaxNetComponent::InitializeComponent(const ComponentSpec &spec) {
// Save off the passed spec for future reference.
spec_ = spec;
// Create and populate a TaskContext for the underlying parser.
TaskContext context;
// Add the specified resources.
for (const Resource &resource : spec_.resource()) {
auto *input = context.GetInput(resource.name());
for (const Part &part : resource.part()) {
auto *input_part = input->add_part();
input_part->set_file_pattern(part.file_pattern());
input_part->set_file_format(part.file_format());
input_part->set_record_format(part.record_format());
}
}
// Add the specified task args to the transition system.
for (const auto &param : spec_.transition_system().parameters()) {
context.SetParameter(param.first, param.second);
}
// Set the arguments for the feature extractor.
std::vector<string> names;
std::vector<string> dims;
std::vector<string> fml;
std::vector<string> predicate_maps;
for (const FixedFeatureChannel &channel : spec.fixed_feature()) {
names.push_back(channel.name());
fml.push_back(channel.fml());
predicate_maps.push_back(channel.predicate_map());
dims.push_back(StrCat(channel.embedding_dim()));
}
context.SetParameter("neurosis_feature_syntax_version", "2");
context.SetParameter("brain_parser_embedding_dims", utils::Join(dims, ";"));
context.SetParameter("brain_parser_predicate_maps",
utils::Join(predicate_maps, ";"));
context.SetParameter("brain_parser_features", utils::Join(fml, ";"));
context.SetParameter("brain_parser_embedding_names", utils::Join(names, ";"));
names.clear();
dims.clear();
fml.clear();
predicate_maps.clear();
std::vector<string> source_components;
std::vector<string> source_layers;
std::vector<string> source_translators;
for (const LinkedFeatureChannel &channel : spec.linked_feature()) {
names.push_back(channel.name());
fml.push_back(channel.fml());
dims.push_back(StrCat(channel.embedding_dim()));
source_components.push_back(channel.source_component());
source_layers.push_back(channel.source_layer());
source_translators.push_back(channel.source_translator());
predicate_maps.push_back("none");
}
context.SetParameter("link_embedding_dims", utils::Join(dims, ";"));
context.SetParameter("link_predicate_maps", utils::Join(predicate_maps, ";"));
context.SetParameter("link_features", utils::Join(fml, ";"));
context.SetParameter("link_embedding_names", utils::Join(names, ";"));
context.SetParameter("link_source_layers", utils::Join(source_layers, ";"));
context.SetParameter("link_source_translators",
utils::Join(source_translators, ";"));
context.SetParameter("link_source_components",
utils::Join(source_components, ";"));
context.SetParameter("parser_transition_system",
spec.transition_system().registered_name());
// Set up the fixed feature extractor.
feature_extractor_.Setup(&context);
feature_extractor_.Init(&context);
feature_extractor_.RequestWorkspaces(&workspace_registry_);
// Set up the underlying transition system.
transition_system_.reset(ParserTransitionSystem::Create(
context.Get("parser_transition_system", "arc-standard")));
transition_system_->Setup(&context);
transition_system_->Init(&context);
// Create label map.
string path = TaskContext::InputFile(*context.GetInput("label-map"));
label_map_ =
SharedStoreUtils::GetWithDefaultName<TermFrequencyMap>(path, 0, 0);
// Set up link feature extractors.
if (spec.linked_feature_size() > 0) {
link_feature_extractor_.Setup(&context);
link_feature_extractor_.Init(&context);
link_feature_extractor_.RequestWorkspaces(&workspace_registry_);
}
// Get the legacy flag for simulating old parser processor behavior. If the
// flag is not set, default to 'false'.
rewrite_root_labels_ = context.Get("rewrite_root_labels", false);
}
std::unique_ptr<Beam<SyntaxNetTransitionState>> SyntaxNetComponent::CreateBeam(
int max_size) {
std::unique_ptr<Beam<SyntaxNetTransitionState>> beam(
new Beam<SyntaxNetTransitionState>(max_size));
auto permission_function = [this](SyntaxNetTransitionState *state,
int action) {
VLOG(3) << "permission_function action:" << action
<< " is_allowed:" << this->IsAllowed(state, action);
return this->IsAllowed(state, action);
};
auto finality_function = [this](SyntaxNetTransitionState *state) {
VLOG(2) << "finality_function is_final:" << this->IsFinal(state);
return this->IsFinal(state);
};
auto oracle_function = [this](SyntaxNetTransitionState *state) {
VLOG(2) << "oracle_function action:" << this->GetOracleLabel(state);
return this->GetOracleLabel(state);
};
auto beam_ptr = beam.get();
auto advance_function = [this, beam_ptr](SyntaxNetTransitionState *state,
int action) {
VLOG(2) << "advance_function beam ptr:" << beam_ptr << " action:" << action;
this->Advance(state, action, beam_ptr);
};
beam->SetFunctions(permission_function, finality_function, advance_function,
oracle_function);
return beam;
}
void SyntaxNetComponent::InitializeData(
const std::vector<std::vector<const TransitionState *>> &parent_states,
int max_beam_size, InputBatchCache *input_data) {
// Save off the input data object.
input_data_ = input_data;
// If beam size has changed, change all beam sizes for existing beams.
if (max_beam_size_ != max_beam_size) {
CHECK_GT(max_beam_size, 0)
<< "Requested max beam size must be greater than 0.";
VLOG(2) << "Adjusting max beam size from " << max_beam_size_ << " to "
<< max_beam_size;
max_beam_size_ = max_beam_size;
for (auto &beam : batch_) {
beam->SetMaxSize(max_beam_size_);
}
}
SentenceInputBatch *sentences = input_data->GetAs<SentenceInputBatch>();
// Expect that the sentence data is the same size as the input states batch.
if (!parent_states.empty()) {
CHECK_EQ(parent_states.size(), sentences->data()->size());
}
// Adjust the beam vector so that it is the correct size for this batch.
if (batch_.size() < sentences->data()->size()) {
VLOG(1) << "Batch size is increased to " << sentences->data()->size()
<< " from " << batch_.size();
for (int i = batch_.size(); i < sentences->data()->size(); ++i) {
batch_.push_back(CreateBeam(max_beam_size));
}
} else if (batch_.size() > sentences->data()->size()) {
VLOG(1) << "Batch size is decreased to " << sentences->data()->size()
<< " from " << batch_.size();
batch_.erase(batch_.begin() + sentences->data()->size(), batch_.end());
} else {
VLOG(1) << "Batch size is constant at " << sentences->data()->size();
}
CHECK_EQ(batch_.size(), sentences->data()->size());
// Fill the beams with the relevant data for that batch.
for (int batch_index = 0; batch_index < sentences->data()->size();
++batch_index) {
// Create a vector of states for this component's beam.
std::vector<std::unique_ptr<SyntaxNetTransitionState>> initial_states;
if (parent_states.empty()) {
// If no states have been passed in, create a single state to seed the
// beam.
initial_states.push_back(
CreateState(&(sentences->data()->at(batch_index))));
} else {
// If states have been passed in, seed the beam with them up to the max
// beam size.
int num_states =
std::min(batch_.at(batch_index)->max_size(),
static_cast<int>(parent_states.at(batch_index).size()));
VLOG(2) << "Creating a beam using " << num_states << " initial states";
for (int i = 0; i < num_states; ++i) {
std::unique_ptr<SyntaxNetTransitionState> state(
CreateState(&(sentences->data()->at(batch_index))));
state->Init(*parent_states.at(batch_index).at(i));
initial_states.push_back(std::move(state));
}
}
batch_.at(batch_index)->Init(std::move(initial_states));
}
}
bool SyntaxNetComponent::IsReady() const { return input_data_ != nullptr; }
string SyntaxNetComponent::Name() const {
return "SyntaxNet-backed beam parser";
}
int SyntaxNetComponent::BatchSize() const { return batch_.size(); }
int SyntaxNetComponent::BeamSize() const { return max_beam_size_; }
int SyntaxNetComponent::StepsTaken(int batch_index) const {
return batch_.at(batch_index)->num_steps();
}
int SyntaxNetComponent::GetBeamIndexAtStep(int step, int current_index,
int batch) const {
return batch_.at(batch)->FindPreviousIndex(current_index, step);
}
int SyntaxNetComponent::GetSourceBeamIndex(int current_index, int batch) const {
return batch_.at(batch)->FindPreviousIndex(current_index, 0);
}
std::function<int(int, int, int)> SyntaxNetComponent::GetStepLookupFunction(
const string &method) {
if (method == "shift-reduce-step") {
// TODO(googleuser): Describe this function.
return [this](int batch_index, int beam_index, int value) {
SyntaxNetTransitionState *state =
batch_.at(batch_index)->beam_state(beam_index);
return state->step_for_token(value);
};
} else if (method == "reduce-step") {
// TODO(googleuser): Describe this function.
return [this](int batch_index, int beam_index, int value) {
SyntaxNetTransitionState *state =
batch_.at(batch_index)->beam_state(beam_index);
return state->parent_step_for_token(value);
};
} else if (method == "parent-shift-reduce-step") {
// TODO(googleuser): Describe this function.
return [this](int batch_index, int beam_index, int value) {
SyntaxNetTransitionState *state =
batch_.at(batch_index)->beam_state(beam_index);
return state->step_for_token(state->parent_step_for_token(value));
};
} else if (method == "reverse-token") {
// TODO(googleuser): Describe this function.
return [this](int batch_index, int beam_index, int value) {
SyntaxNetTransitionState *state =
batch_.at(batch_index)->beam_state(beam_index);
int result = state->sentence()->sentence()->token_size() - value - 1;
if (result >= 0 && result < state->sentence()->sentence()->token_size()) {
return result;
} else {
return -1;
}
};
} else {
LOG(FATAL) << "Unable to find step lookup function " << method;
}
}
void SyntaxNetComponent::AdvanceFromPrediction(const float transition_matrix[],
int transition_matrix_length) {
VLOG(2) << "Advancing from prediction.";
int matrix_index = 0;
int num_labels = transition_system_->NumActions(label_map_->Size());
for (int i = 0; i < batch_.size(); ++i) {
int max_beam_size = batch_.at(i)->max_size();
int matrix_size = num_labels * max_beam_size;
CHECK_LE(matrix_index + matrix_size, transition_matrix_length);
if (!batch_.at(i)->IsTerminal()) {
batch_.at(i)->AdvanceFromPrediction(&transition_matrix[matrix_index],
matrix_size, num_labels);
}
matrix_index += num_labels * max_beam_size;
}
}
void SyntaxNetComponent::AdvanceFromOracle() {
VLOG(2) << "Advancing from oracle.";
for (auto &beam : batch_) {
beam->AdvanceFromOracle();
}
}
bool SyntaxNetComponent::IsTerminal() const {
VLOG(2) << "Checking terminal status.";
for (const auto &beam : batch_) {
if (!beam->IsTerminal()) {
return false;
}
}
return true;
}
std::vector<std::vector<const TransitionState *>>
SyntaxNetComponent::GetBeam() {
std::vector<std::vector<const TransitionState *>> state_beam;
for (auto &beam : batch_) {
    // Because this component only finalizes the data of the highest-ranked
    // state in each beam, the next component should only be initialized from
    // that highest-ranked state.
state_beam.push_back({beam->beam().at(0)});
}
return state_beam;
}
int SyntaxNetComponent::GetFixedFeatures(
std::function<int32 *(int)> allocate_indices,
std::function<int64 *(int)> allocate_ids,
std::function<float *(int)> allocate_weights, int channel_id) const {
std::vector<SparseFeatures> features;
const int channel_size = spec_.fixed_feature(channel_id).size();
// For every beam in the batch...
for (const auto &beam : batch_) {
// For every element in the beam...
for (int beam_idx = 0; beam_idx < beam->size(); ++beam_idx) {
// Get the SparseFeatures from the feature extractor.
auto state = beam->beam_state(beam_idx);
const std::vector<std::vector<SparseFeatures>> sparse_features =
feature_extractor_.ExtractSparseFeatures(
*(state->sentence()->workspace()), *(state->parser_state()));
// Hold the SparseFeatures for later processing.
for (const SparseFeatures &f : sparse_features[channel_id]) {
features.emplace_back(f);
if (do_tracing_) {
FixedFeatures fixed_features;
for (const string &name : f.description()) {
fixed_features.add_value_name(name);
}
fixed_features.set_feature_name("");
auto *trace = GetLastStepInTrace(state->mutable_trace());
auto *fixed_trace = trace->mutable_fixed_feature_trace(channel_id);
*fixed_trace->add_value_trace() = fixed_features;
}
}
}
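    // Pad the feature list out to max_beam_size so that every beam
    // contributes a fixed number of feature slots.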
const int pad_amount = max_beam_size_ - beam->size();
features.resize(features.size() + pad_amount * channel_size);
}
int feature_count = 0;
for (const auto &feature : features) {
feature_count += feature.id_size();
}
VLOG(2) << "Feature count is " << feature_count;
int32 *indices_tensor = allocate_indices(feature_count);
int64 *ids_tensor = allocate_ids(feature_count);
float *weights_tensor = allocate_weights(feature_count);
int array_index = 0;
for (int feature_index = 0; feature_index < features.size();
++feature_index) {
VLOG(2) << "Extracting for feature_index " << feature_index;
const auto feature = features[feature_index];
for (int sub_idx = 0; sub_idx < feature.id_size(); ++sub_idx) {
indices_tensor[array_index] = feature_index;
ids_tensor[array_index] = feature.id(sub_idx);
if (sub_idx < feature.weight_size()) {
weights_tensor[array_index] = feature.weight(sub_idx);
} else {
weights_tensor[array_index] = 1.0;
}
VLOG(2) << "Feature index: " << indices_tensor[array_index]
<< " id: " << ids_tensor[array_index]
<< " weight: " << weights_tensor[array_index];
++array_index;
}
}
return feature_count;
}
int SyntaxNetComponent::BulkGetFixedFeatures(
const BulkFeatureExtractor &extractor) {
// Allocate a vector of SparseFeatures per channel.
const int num_channels = spec_.fixed_feature_size();
std::vector<int> channel_size(num_channels);
for (int i = 0; i < num_channels; ++i) {
channel_size[i] = spec_.fixed_feature(i).size();
}
std::vector<std::vector<SparseFeatures>> features(num_channels);
std::vector<std::vector<int>> feature_indices(num_channels);
std::vector<std::vector<int>> step_indices(num_channels);
std::vector<std::vector<int>> element_indices(num_channels);
std::vector<int> feature_counts(num_channels);
int step_count = 0;
while (!IsTerminal()) {
int current_element = 0;
// For every beam in the batch...
for (const auto &beam : batch_) {
// For every element in the beam...
for (int beam_idx = 0; beam_idx < beam->size(); ++beam_idx) {
// Get the SparseFeatures from the parser.
auto state = beam->beam_state(beam_idx);
const std::vector<std::vector<SparseFeatures>> sparse_features =
feature_extractor_.ExtractSparseFeatures(
*(state->sentence()->workspace()), *(state->parser_state()));
for (int channel_id = 0; channel_id < num_channels; ++channel_id) {
int feature_count = 0;
for (const SparseFeatures &f : sparse_features[channel_id]) {
// Trace, if requested.
if (do_tracing_) {
FixedFeatures fixed_features;
for (const string &name : f.description()) {
fixed_features.add_value_name(name);
}
fixed_features.set_feature_name("");
auto *trace = GetLastStepInTrace(state->mutable_trace());
auto *fixed_trace =
trace->mutable_fixed_feature_trace(channel_id);
*fixed_trace->add_value_trace() = fixed_features;
}
// Hold the SparseFeatures for later processing.
features[channel_id].emplace_back(f);
element_indices[channel_id].emplace_back(current_element);
step_indices[channel_id].emplace_back(step_count);
feature_indices[channel_id].emplace_back(feature_count);
feature_counts[channel_id] += f.id_size();
++feature_count;
}
}
++current_element;
}
      // Skip unused beam slots so that current_element stays aligned with a
      // full beam of max_beam_size elements.
int pad_amount = max_beam_size_ - beam->size();
current_element += pad_amount;
}
AdvanceFromOracle();
++step_count;
}
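  // All beams are now terminal; copy the accumulated features into the
  // per-channel output tensors.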
const int total_steps = step_count;
const int num_elements = batch_.size() * max_beam_size_;
// This would be a good place to add threading.
for (int channel_id = 0; channel_id < num_channels; ++channel_id) {
int feature_count = feature_counts[channel_id];
LOG(INFO) << "Feature count is " << feature_count << " for channel "
<< channel_id;
int32 *indices_tensor =
extractor.AllocateIndexMemory(channel_id, feature_count);
int64 *ids_tensor = extractor.AllocateIdMemory(channel_id, feature_count);
float *weights_tensor =
extractor.AllocateWeightMemory(channel_id, feature_count);
int array_index = 0;
for (int feat_idx = 0; feat_idx < features[channel_id].size(); ++feat_idx) {
const auto &feature = features[channel_id][feat_idx];
int element_index = element_indices[channel_id][feat_idx];
int step_index = step_indices[channel_id][feat_idx];
int feature_index = feature_indices[channel_id][feat_idx];
for (int sub_idx = 0; sub_idx < feature.id_size(); ++sub_idx) {
indices_tensor[array_index] =
extractor.GetIndex(total_steps, num_elements, feature_index,
element_index, step_index);
ids_tensor[array_index] = feature.id(sub_idx);
if (sub_idx < feature.weight_size()) {
weights_tensor[array_index] = feature.weight(sub_idx);
} else {
weights_tensor[array_index] = 1.0;
}
++array_index;
}
}
}
return step_count;
}
std::vector<LinkFeatures> SyntaxNetComponent::GetRawLinkFeatures(
int channel_id) const {
std::vector<LinkFeatures> features;
const int channel_size = spec_.linked_feature(channel_id).size();
std::unique_ptr<std::vector<string>> feature_names;
if (do_tracing_) {
feature_names.reset(new std::vector<string>);
*feature_names = utils::Split(spec_.linked_feature(channel_id).fml(), ' ');
}
// For every beam in the batch...
for (int batch_idx = 0; batch_idx < batch_.size(); ++batch_idx) {
// For every element in the beam...
const auto &beam = batch_[batch_idx];
for (int beam_idx = 0; beam_idx < beam->size(); ++beam_idx) {
// Get the raw link features from the linked feature extractor.
auto state = beam->beam_state(beam_idx);
std::vector<FeatureVector> raw_features(
link_feature_extractor_.NumEmbeddings());
link_feature_extractor_.ExtractFeatures(*(state->sentence()->workspace()),
*(state->parser_state()),
&raw_features);
// Add the raw feature values to the LinkFeatures proto.
CHECK_LT(channel_id, raw_features.size());
for (int i = 0; i < raw_features[channel_id].size(); ++i) {
features.emplace_back();
features.back().set_feature_value(raw_features[channel_id].value(i));
features.back().set_batch_idx(batch_idx);
features.back().set_beam_idx(beam_idx);
if (do_tracing_) {
features.back().set_feature_name(feature_names->at(i));
}
}
}
// Pad the beam out to max_beam_size.
int pad_amount = max_beam_size_ - beam->size();
features.resize(features.size() + pad_amount * channel_size);
}
return features;
}
std::vector<std::vector<int>> SyntaxNetComponent::GetOracleLabels() const {
std::vector<std::vector<int>> oracle_labels;
for (const auto &beam : batch_) {
oracle_labels.emplace_back();
for (int beam_idx = 0; beam_idx < beam->size(); ++beam_idx) {
      // Get the oracle label for this state.
auto state = beam->beam_state(beam_idx);
oracle_labels.back().push_back(GetOracleLabel(state));
}
}
return oracle_labels;
}
void SyntaxNetComponent::FinalizeData() {
// This chooses the top-scoring member of the beam to annotate the underlying
// document.
VLOG(2) << "Finalizing data.";
for (auto &beam : batch_) {
if (beam->size() != 0) {
auto top_state = beam->beam_state(0);
VLOG(3) << "Finalizing for sentence: "
<< top_state->sentence()->sentence()->ShortDebugString();
top_state->parser_state()->AddParseToDocument(
top_state->sentence()->sentence(), rewrite_root_labels_);
VLOG(3) << "Sentence is now: "
<< top_state->sentence()->sentence()->ShortDebugString();
} else {
LOG(WARNING) << "Attempting to finalize an empty beam for component "
<< spec_.name();
}
}
}
void SyntaxNetComponent::ResetComponent() {
for (auto &beam : batch_) {
beam->Reset();
}
input_data_ = nullptr;
max_beam_size_ = 0;
}
std::unique_ptr<SyntaxNetTransitionState> SyntaxNetComponent::CreateState(
SyntaxNetSentence *sentence) {
VLOG(3) << "Creating state for sentence "
<< sentence->sentence()->DebugString();
std::unique_ptr<ParserState> parser_state(new ParserState(
sentence->sentence(), transition_system_->NewTransitionState(false),
label_map_));
sentence->workspace()->Reset(workspace_registry_);
feature_extractor_.Preprocess(sentence->workspace(), parser_state.get());
link_feature_extractor_.Preprocess(sentence->workspace(), parser_state.get());
std::unique_ptr<SyntaxNetTransitionState> transition_state(
new SyntaxNetTransitionState(std::move(parser_state), sentence));
return transition_state;
}
bool SyntaxNetComponent::IsAllowed(SyntaxNetTransitionState *state,
int action) const {
return transition_system_->IsAllowedAction(action, *(state->parser_state()));
}
bool SyntaxNetComponent::IsFinal(SyntaxNetTransitionState *state) const {
return transition_system_->IsFinalState(*(state->parser_state()));
}
int SyntaxNetComponent::GetOracleLabel(SyntaxNetTransitionState *state) const {
if (IsFinal(state)) {
// It is not permitted to request an oracle label from a sentence that is
// in a final state.
return -1;
} else {
return transition_system_->GetNextGoldAction(*(state->parser_state()));
}
}
void SyntaxNetComponent::Advance(SyntaxNetTransitionState *state, int action,
Beam<SyntaxNetTransitionState> *beam) {
auto parser_state = state->parser_state();
auto sentence_size = state->sentence()->sentence()->token_size();
const int num_steps = beam->num_steps();
if (transition_system_->SupportsActionMetaData()) {
const int parent_idx =
transition_system_->ParentIndex(*parser_state, action);
constexpr int kShiftAction = -1;
if (parent_idx == kShiftAction) {
if (parser_state->Next() < sentence_size && parser_state->Next() >= 0) {
        // Record the step only while there is still input left to consume;
        // once all input has been consumed this is not a real shift action
        // and is skipped.
state->set_step_for_token(parser_state->Next(), num_steps);
}
} else if (parent_idx >= 0) {
VLOG(2) << spec_.name() << ": Updating pointer: " << parent_idx << " -> "
<< num_steps;
state->set_step_for_token(parent_idx, num_steps);
const int child_idx =
transition_system_->ChildIndex(*parser_state, action);
assert(child_idx >= 0 && child_idx < sentence_size);
state->set_parent_for_token(child_idx, parent_idx);
VLOG(2) << spec_.name() << ": Updating parent for child: " << parent_idx
<< " -> " << child_idx;
state->set_parent_step_for_token(child_idx, num_steps);
} else {
VLOG(2) << spec_.name() << ": Invalid parent index: " << parent_idx;
}
}
if (do_tracing_) {
auto *trace = state->mutable_trace();
auto *last_step = GetLastStepInTrace(trace);
// Add action to the prior step.
last_step->set_caption(
transition_system_->ActionAsString(action, *parser_state));
last_step->set_step_finished(true);
}
transition_system_->PerformAction(action, parser_state);
if (do_tracing_) {
// Add info for the next step.
*state->mutable_trace()->add_step_trace() = GetNewStepTrace(spec_, *state);
}
}
void SyntaxNetComponent::InitializeTracing() {
do_tracing_ = true;
CHECK(IsReady()) << "Cannot initialize trace before InitializeData().";
// Initialize each element of the beam with a new trace.
for (auto &beam : batch_) {
for (int beam_idx = 0; beam_idx < beam->size(); ++beam_idx) {
SyntaxNetTransitionState *state = beam->beam_state(beam_idx);
std::unique_ptr<ComponentTrace> trace(new ComponentTrace());
trace->set_name(spec_.name());
*trace->add_step_trace() = GetNewStepTrace(spec_, *state);
state->set_trace(std::move(trace));
}
}
feature_extractor_.set_add_strings(true);
}
void SyntaxNetComponent::DisableTracing() {
do_tracing_ = false;
feature_extractor_.set_add_strings(false);
}
void SyntaxNetComponent::AddTranslatedLinkFeaturesToTrace(
const std::vector<LinkFeatures> &features, int channel_id) {
CHECK(do_tracing_) << "Tracing is not enabled.";
int linear_idx = 0;
const int channel_size = spec_.linked_feature(channel_id).size();
// For every beam in the batch...
for (const auto &beam : batch_) {
// For every element in the beam...
for (int beam_idx = 0; beam_idx < max_beam_size_; ++beam_idx) {
for (int feature_idx = 0; feature_idx < channel_size; ++feature_idx) {
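        // Padded beam slots carry no state; they are skipped here but still
        // advance linear_idx so that indexing stays aligned.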
if (beam_idx < beam->size()) {
auto state = beam->beam_state(beam_idx);
auto *trace = GetLastStepInTrace(state->mutable_trace());
auto *link_trace = trace->mutable_linked_feature_trace(channel_id);
if (features[linear_idx].feature_value() >= 0 &&
features[linear_idx].step_idx() >= 0) {
*link_trace->add_value_trace() = features[linear_idx];
}
}
++linear_idx;
}
}
}
}
std::vector<std::vector<ComponentTrace>> SyntaxNetComponent::GetTraceProtos()
const {
std::vector<std::vector<ComponentTrace>> traces;
// For every beam in the batch...
for (const auto &beam : batch_) {
std::vector<ComponentTrace> beam_trace;
// For every element in the beam...
for (int beam_idx = 0; beam_idx < beam->size(); ++beam_idx) {
auto state = beam->beam_state(beam_idx);
beam_trace.push_back(*state->mutable_trace());
}
traces.push_back(beam_trace);
}
return traces;
}
REGISTER_DRAGNN_COMPONENT(SyntaxNetComponent);
} // namespace dragnn
} // namespace syntaxnet
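As a caller-side illustration of the score layout consumed by AdvanceFromPrediction() above, here is a minimal sketch; it is not part of the DRAGNN sources, and the component pointer, batch size, beam size, and action count are assumed to come from the surrounding setup.

#include <vector>
#include "dragnn/components/syntaxnet/syntaxnet_component.h"
// Hypothetical illustration only: each batch element owns a contiguous block
// of max_beam_size * num_actions scores (see AdvanceFromPrediction() above);
// within a block, the beam consumes num_actions scores per beam slot. This
// fills a uniform score array and advances the component by one step.
void AdvanceOnce(syntaxnet::dragnn::SyntaxNetComponent *component,
                 int batch_size, int max_beam_size, int num_actions) {
  std::vector<float> scores(batch_size * max_beam_size * num_actions, 10.0f);
  component->AdvanceFromPrediction(scores.data(),
                                   static_cast<int>(scores.size()));
}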
#ifndef NLP_SAFT_OPENSOURCE_DRAGNN_COMPONENTS_SYNTAXNET_SYNTAXNET_COMPONENT_H_
#define NLP_SAFT_OPENSOURCE_DRAGNN_COMPONENTS_SYNTAXNET_SYNTAXNET_COMPONENT_H_
#include <vector>
#include "dragnn/components/syntaxnet/syntaxnet_link_feature_extractor.h"
#include "dragnn/components/syntaxnet/syntaxnet_transition_state.h"
#include "dragnn/components/util/bulk_feature_extractor.h"
#include "dragnn/core/beam.h"
#include "dragnn/core/input_batch_cache.h"
#include "dragnn/core/interfaces/component.h"
#include "dragnn/core/interfaces/transition_state.h"
#include "dragnn/protos/data.pb.h"
#include "dragnn/protos/spec.pb.h"
#include "dragnn/protos/trace.pb.h"
#include "syntaxnet/base.h"
#include "syntaxnet/parser_transitions.h"
#include "syntaxnet/registry.h"
#include "syntaxnet/task_context.h"
namespace syntaxnet {
namespace dragnn {
class SyntaxNetComponent : public Component {
public:
// Create a SyntaxNet-backed DRAGNN component.
SyntaxNetComponent();
// Initializes this component from the spec.
void InitializeComponent(const ComponentSpec &spec) override;
// Provides the previous beam to the component.
void InitializeData(
const std::vector<std::vector<const TransitionState *>> &states,
int max_beam_size, InputBatchCache *input_data) override;
// Returns true if the component has had InitializeData called on it since
// the last time it was reset.
bool IsReady() const override;
// Returns the string name of this component.
string Name() const override;
// Returns the number of steps taken by the given batch in this component.
int StepsTaken(int batch_index) const override;
// Returns the current batch size of the component's underlying data.
int BatchSize() const override;
// Returns the maximum beam size of this component.
int BeamSize() const override;
  // Returns the beam index of the item which is currently at index
  // 'current_index', as it was when the beam was at step 'step', for batch
  // element 'batch'.
int GetBeamIndexAtStep(int step, int current_index, int batch) const override;
  // Returns the source index of the item which is currently at index
  // 'current_index' for batch element 'batch'. This index points into the
  // final beam of the Component that this Component was initialized from.
int GetSourceBeamIndex(int current_index, int batch) const override;
// Request a translation function based on the given method string.
// The translation function will be called with arguments (batch, beam, value)
// and should return the step index corresponding to the given value, for the
// data in the given beam and batch.
std::function<int(int, int, int)> GetStepLookupFunction(
const string &method) override;
// Advances this component from the given transition matrix.
void AdvanceFromPrediction(const float transition_matrix[],
int transition_matrix_length) override;
// Advances this component from the state oracles.
void AdvanceFromOracle() override;
// Returns true if all states within this component are terminal.
bool IsTerminal() const override;
// Returns the current batch of beams for this component.
std::vector<std::vector<const TransitionState *>> GetBeam() override;
// Extracts and populates the vector of FixedFeatures for the specified
// channel.
int GetFixedFeatures(std::function<int32 *(int)> allocate_indices,
std::function<int64 *(int)> allocate_ids,
std::function<float *(int)> allocate_weights,
int channel_id) const override;
// Extracts and populates all FixedFeatures for all channels, advancing this
// component via the oracle until it is terminal.
int BulkGetFixedFeatures(const BulkFeatureExtractor &extractor) override;
// Extracts and returns the vector of LinkFeatures for the specified
// channel. Note: these are NOT translated.
std::vector<LinkFeatures> GetRawLinkFeatures(int channel_id) const override;
// Returns a vector of oracle labels for each element in the beam and
// batch.
std::vector<std::vector<int>> GetOracleLabels() const override;
// Annotate the underlying data object with the results of this Component's
// calculation.
void FinalizeData() override;
// Reset this component.
void ResetComponent() override;
// Initializes the component for tracing execution. This will typically have
// the side effect of slowing down all subsequent Component calculations
// and storing a trace in memory that can be returned by GetTraceProtos().
void InitializeTracing() override;
// Disables tracing, freeing any additional memory and avoiding triggering
// additional computation in the future.
void DisableTracing() override;
std::vector<std::vector<ComponentTrace>> GetTraceProtos() const override;
void AddTranslatedLinkFeaturesToTrace(
const std::vector<LinkFeatures> &features, int channel_id) override;
private:
friend class SyntaxNetComponentTest;
friend class SyntaxNetTransitionStateTest;
// Permission function for this component.
bool IsAllowed(SyntaxNetTransitionState *state, int action) const;
  // Returns true if this state is final.
bool IsFinal(SyntaxNetTransitionState *state) const;
// Oracle function for this component.
int GetOracleLabel(SyntaxNetTransitionState *state) const;
// State advance function for this component.
void Advance(SyntaxNetTransitionState *state, int action,
Beam<SyntaxNetTransitionState> *beam);
  // Creates a new state for the given SyntaxNetSentence.
std::unique_ptr<SyntaxNetTransitionState> CreateState(
SyntaxNetSentence *example);
// Creates a newly initialized Beam.
std::unique_ptr<Beam<SyntaxNetTransitionState>> CreateBeam(int max_size);
// Transition system.
std::unique_ptr<ParserTransitionSystem> transition_system_;
// Label map for transition system.
const TermFrequencyMap *label_map_;
  // Extractor for fixed features.
ParserEmbeddingFeatureExtractor feature_extractor_;
// Extractor for linked features.
SyntaxNetLinkFeatureExtractor link_feature_extractor_;
// Internal workspace registry for use in feature extraction.
WorkspaceRegistry workspace_registry_;
// Switch for simulating legacy parser behaviour.
bool rewrite_root_labels_;
// The ComponentSpec used to initialize this component.
ComponentSpec spec_;
  // State search beams, one per batch element.
std::vector<std::unique_ptr<Beam<SyntaxNetTransitionState>>> batch_;
// Current max beam size.
int max_beam_size_;
// Underlying input data.
InputBatchCache *input_data_;
// Whether or not to trace for each batch and beam element.
bool do_tracing_ = false;
};
} // namespace dragnn
} // namespace syntaxnet
#endif // NLP_SAFT_OPENSOURCE_DRAGNN_COMPONENTS_SYNTAXNET_SYNTAXNET_COMPONENT_H_
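For orientation, here is a minimal, hypothetical driver loop for the interface declared above, assembled from these declarations and the tests that follow. ParseWithOracle and its parameters are illustrative names only, and the ComponentSpec and serialized Sentence protos are assumed to be produced elsewhere (e.g. read from a MasterSpec, as in the test fixture below).

#include <string>
#include <vector>
#include "dragnn/components/syntaxnet/syntaxnet_component.h"
#include "dragnn/core/input_batch_cache.h"

namespace syntaxnet {
namespace dragnn {

// Hypothetical sketch: runs one batch of serialized Sentence protos through a
// SyntaxNetComponent using the oracle, then writes the parses back into the
// underlying sentences.
void ParseWithOracle(const ComponentSpec &spec,
                     const std::vector<string> &serialized_sentences,
                     int beam_size) {
  SyntaxNetComponent component;
  component.InitializeComponent(spec);
  InputBatchCache data(serialized_sentences);
  // An empty parent-state vector seeds each beam with a single fresh state.
  component.InitializeData({}, beam_size, &data);
  while (!component.IsTerminal()) {
    component.AdvanceFromOracle();
  }
  // Annotates each sentence with its top-scoring parse.
  component.FinalizeData();
  component.ResetComponent();
}

}  // namespace dragnn
}  // namespace syntaxnet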
#include "dragnn/components/syntaxnet/syntaxnet_component.h"
#include "dragnn/core/input_batch_cache.h"
#include "dragnn/core/test/generic.h"
#include "dragnn/core/test/mock_transition_state.h"
#include "dragnn/io/sentence_input_batch.h"
#include "syntaxnet/sentence.pb.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/io/path.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/protobuf.h"
#include "tensorflow/core/platform/test.h"
// This test suite is intended to validate the contracts that the DRAGNN
// system expects from all component subclasses. Developers creating new
// Components should copy this test and modify it as necessary, using it to
// ensure their component conforms to DRAGNN expectations.
namespace syntaxnet {
namespace dragnn {
namespace {
const char kSentence0[] = R"(
token {
word: "Sentence" start: 0 end: 7 tag: "NN" category: "NOUN" label: "ROOT"
break_level: NO_BREAK
}
token {
word: "0" start: 9 end: 9 head: 0 tag: "CD" category: "NUM" label: "num"
break_level: SPACE_BREAK
}
token {
word: "." start: 10 end: 10 head: 0 tag: "." category: "." label: "punct"
break_level: NO_BREAK
}
)";
const char kSentence1[] = R"(
token {
word: "Sentence" start: 0 end: 7 tag: "NN" category: "NOUN" label: "ROOT"
break_level: NO_BREAK
}
token {
word: "1" start: 9 end: 9 head: 0 tag: "CD" category: "NUM" label: "num"
break_level: SPACE_BREAK
}
token {
word: "." start: 10 end: 10 head: 0 tag: "." category: "." label: "punct"
break_level: NO_BREAK
}
)";
const char kLongSentence[] = R"(
token {
word: "Sentence" start: 0 end: 7 tag: "NN" category: "NOUN" label: "ROOT"
break_level: NO_BREAK
}
token {
word: "1" start: 9 end: 9 head: 0 tag: "CD" category: "NUM" label: "num"
break_level: SPACE_BREAK
}
token {
word: "2" start: 10 end: 10 head: 0 tag: "CD" category: "NUM" label: "num"
break_level: SPACE_BREAK
}
token {
word: "3" start: 11 end: 11 head: 0 tag: "CD" category: "NUM" label: "num"
break_level: SPACE_BREAK
}
token {
word: "." start: 12 end: 12 head: 0 tag: "." category: "." label: "punct"
break_level: NO_BREAK
}
)";
} // namespace
using testing::Return;
class SyntaxNetComponentTest : public ::testing::Test {
public:
std::unique_ptr<SyntaxNetComponent> CreateParser(
const std::vector<std::vector<const TransitionState *>> &states,
const std::vector<string> &data) {
constexpr int kBeamSize = 2;
return CreateParserWithBeamSize(kBeamSize, states, data);
}
std::unique_ptr<SyntaxNetComponent> CreateParserWithBeamSize(
int beam_size,
const std::vector<std::vector<const TransitionState *>> &states,
const std::vector<string> &data) {
// Get the master spec proto from the test data directory.
MasterSpec master_spec;
string file_name = tensorflow::io::JoinPath(
test::GetTestDataPrefix(), "dragnn/components/syntaxnet/testdata",
"master_spec.textproto");
TF_CHECK_OK(tensorflow::ReadTextProto(tensorflow::Env::Default(), file_name,
&master_spec));
// Get all the resource protos from the test data directory.
for (Resource &resource :
*(master_spec.mutable_component(0)->mutable_resource())) {
resource.mutable_part(0)->set_file_pattern(tensorflow::io::JoinPath(
test::GetTestDataPrefix(), "dragnn/components/syntaxnet/testdata",
resource.part(0).file_pattern()));
}
data_.reset(new InputBatchCache(data));
// Create a parser component with the specified beam size.
std::unique_ptr<SyntaxNetComponent> parser_component(
new SyntaxNetComponent());
parser_component->InitializeComponent(*(master_spec.mutable_component(0)));
parser_component->InitializeData(states, beam_size, data_.get());
return parser_component;
}
const std::vector<Beam<SyntaxNetTransitionState> *> GetBeams(
SyntaxNetComponent *component) const {
std::vector<Beam<SyntaxNetTransitionState> *> return_vector;
for (const auto &beam : component->batch_) {
return_vector.push_back(beam.get());
}
return return_vector;
}
std::unique_ptr<InputBatchCache> data_;
};
TEST_F(SyntaxNetComponentTest, AdvancesFromOracleAndTerminates) {
// Create an empty input batch and beam vector to initialize the parser.
Sentence sentence_0;
TextFormat::ParseFromString(kSentence0, &sentence_0);
string sentence_0_str;
sentence_0.SerializeToString(&sentence_0_str);
auto test_parser = CreateParser({}, {sentence_0_str});
constexpr int kNumTokensInSentence = 3;
// The master spec will initialize a parser, so expect 2*N transitions.
constexpr int kExpectedNumTransitions = kNumTokensInSentence * 2;
for (int i = 0; i < kExpectedNumTransitions; ++i) {
EXPECT_FALSE(test_parser->IsTerminal());
test_parser->AdvanceFromOracle();
}
// At this point, the test parser should be terminal.
EXPECT_TRUE(test_parser->IsTerminal());
// Check that the component is reporting 2N steps taken.
EXPECT_EQ(test_parser->StepsTaken(0), kExpectedNumTransitions);
// Make sure the parser doesn't segfault.
test_parser->FinalizeData();
}
TEST_F(SyntaxNetComponentTest, AdvancesFromPredictionAndTerminates) {
// Create an empty input batch and beam vector to initialize the parser.
Sentence sentence_0;
TextFormat::ParseFromString(kSentence0, &sentence_0);
string sentence_0_str;
sentence_0.SerializeToString(&sentence_0_str);
auto test_parser = CreateParser({}, {sentence_0_str});
constexpr int kNumTokensInSentence = 3;
// The master spec will initialize a parser, so expect 2*N transitions.
constexpr int kExpectedNumTransitions = kNumTokensInSentence * 2;
// There are 93 possible transitions for any given state. Create a transition
// array with a score of 10.0 for each transition.
constexpr int kBeamSize = 2;
constexpr int kNumPossibleTransitions = 93;
constexpr float kTransitionValue = 10.0;
float transition_matrix[kNumPossibleTransitions * kBeamSize];
for (int i = 0; i < kNumPossibleTransitions * kBeamSize; ++i) {
transition_matrix[i] = kTransitionValue;
}
// Transition the expected number of times.
for (int i = 0; i < kExpectedNumTransitions; ++i) {
EXPECT_FALSE(test_parser->IsTerminal());
test_parser->AdvanceFromPrediction(transition_matrix,
kNumPossibleTransitions * kBeamSize);
}
// At this point, the test parser should be terminal.
EXPECT_TRUE(test_parser->IsTerminal());
// Check that the component is reporting 2N steps taken.
EXPECT_EQ(test_parser->StepsTaken(0), kExpectedNumTransitions);
// Prepare to validate the batched beams.
auto beam = test_parser->GetBeam();
// All beams should only have one element.
for (const auto &per_beam : beam) {
EXPECT_EQ(per_beam.size(), 1);
}
// The final states should have kExpectedNumTransitions * kTransitionValue.
EXPECT_EQ(beam.at(0).at(0)->GetScore(),
kTransitionValue * kExpectedNumTransitions);
// Make sure the parser doesn't segfault.
test_parser->FinalizeData();
// TODO(googleuser): What should the finalized data look like?
}
TEST_F(SyntaxNetComponentTest, RetainsPassedTransitionStateData) {
  // Create and initialize the mock transition states.
MockTransitionState mock_state_one;
constexpr int kParentBeamIndexOne = 1138;
constexpr float kParentScoreOne = 7.2;
EXPECT_CALL(mock_state_one, GetBeamIndex())
.WillRepeatedly(Return(kParentBeamIndexOne));
EXPECT_CALL(mock_state_one, GetScore())
.WillRepeatedly(Return(kParentScoreOne));
MockTransitionState mock_state_two;
constexpr int kParentBeamIndexTwo = 1123;
constexpr float kParentScoreTwo = 42.03;
EXPECT_CALL(mock_state_two, GetBeamIndex())
.WillRepeatedly(Return(kParentBeamIndexTwo));
EXPECT_CALL(mock_state_two, GetScore())
.WillRepeatedly(Return(kParentScoreTwo));
// Create an empty input batch and beam vector to initialize the parser.
Sentence sentence_0;
TextFormat::ParseFromString(kSentence0, &sentence_0);
string sentence_0_str;
sentence_0.SerializeToString(&sentence_0_str);
auto test_parser =
CreateParser({{&mock_state_one, &mock_state_two}}, {sentence_0_str});
constexpr int kNumTokensInSentence = 3;
// The master spec will initialize a parser, so expect 2*N transitions.
constexpr int kExpectedNumTransitions = kNumTokensInSentence * 2;
// There are 93 possible transitions for any given state. Create a transition
// array with a score of 10.0 for each transition.
constexpr int kBeamSize = 2;
constexpr int kNumPossibleTransitions = 93;
constexpr float kTransitionValue = 10.0;
float transition_matrix[kNumPossibleTransitions * kBeamSize];
for (int i = 0; i < kNumPossibleTransitions * kBeamSize; ++i) {
transition_matrix[i] = kTransitionValue;
}
  // Transition the expected number of times.
for (int i = 0; i < kExpectedNumTransitions; ++i) {
EXPECT_FALSE(test_parser->IsTerminal());
test_parser->AdvanceFromPrediction(transition_matrix,
kNumPossibleTransitions * kBeamSize);
}
// At this point, the test parser should be terminal.
EXPECT_TRUE(test_parser->IsTerminal());
// Check that the component is reporting 2N steps taken.
EXPECT_EQ(test_parser->StepsTaken(0), kExpectedNumTransitions);
// The final states should have kExpectedNumTransitions * kTransitionValue,
// plus the higher parent state score (from state two).
auto beam = test_parser->GetBeam();
EXPECT_EQ(beam.at(0).at(0)->GetScore(),
kTransitionValue * kExpectedNumTransitions + kParentScoreTwo);
// Make sure that the parent state is reported correctly.
EXPECT_EQ(test_parser->GetSourceBeamIndex(0, 0), kParentBeamIndexTwo);
// Make sure the parser doesn't segfault.
test_parser->FinalizeData();
// TODO(googleuser): What should the finalized data look like?
}
TEST_F(SyntaxNetComponentTest, AdvancesFromPredictionForMultiSentenceBatches) {
// Create an empty input batch and beam vector to initialize the parser.
Sentence sentence_0;
TextFormat::ParseFromString(kSentence0, &sentence_0);
string sentence_0_str;
sentence_0.SerializeToString(&sentence_0_str);
Sentence sentence_1;
TextFormat::ParseFromString(kSentence1, &sentence_1);
string sentence_1_str;
sentence_1.SerializeToString(&sentence_1_str);
auto test_parser = CreateParser({}, {sentence_0_str, sentence_1_str});
constexpr int kNumTokensInSentence = 3;
// The master spec will initialize a parser, so expect 2*N transitions.
constexpr int kExpectedNumTransitions = kNumTokensInSentence * 2;
// There are 93 possible transitions for any given state. Create a transition
// array with a score of 10.0 for each transition.
constexpr int kBatchSize = 2;
constexpr int kBeamSize = 2;
constexpr int kNumPossibleTransitions = 93;
constexpr float kTransitionValue = 10.0;
float transition_matrix[kNumPossibleTransitions * kBeamSize * kBatchSize];
for (int i = 0; i < kNumPossibleTransitions * kBeamSize * kBatchSize; ++i) {
transition_matrix[i] = kTransitionValue;
}
// Transition the expected number of times.
for (int i = 0; i < kExpectedNumTransitions; ++i) {
EXPECT_FALSE(test_parser->IsTerminal());
test_parser->AdvanceFromPrediction(
transition_matrix, kNumPossibleTransitions * kBeamSize * kBatchSize);
}
// At this point, the test parser should be terminal.
EXPECT_TRUE(test_parser->IsTerminal());
// Check that the component is reporting 2N steps taken.
EXPECT_EQ(test_parser->StepsTaken(0), kExpectedNumTransitions);
EXPECT_EQ(test_parser->StepsTaken(1), kExpectedNumTransitions);
// The final states should have kExpectedNumTransitions * kTransitionValue.
auto beam = test_parser->GetBeam();
EXPECT_EQ(beam.at(0).at(0)->GetScore(),
kTransitionValue * kExpectedNumTransitions);
EXPECT_EQ(beam.at(1).at(0)->GetScore(),
kTransitionValue * kExpectedNumTransitions);
// Make sure the parser doesn't segfault.
test_parser->FinalizeData();
// TODO(googleuser): What should the finalized data look like?
}
TEST_F(SyntaxNetComponentTest,
AdvancesFromPredictionForVaryingLengthSentences) {
// Create an empty input batch and beam vector to initialize the parser.
Sentence sentence_0;
TextFormat::ParseFromString(kSentence0, &sentence_0);
string sentence_0_str;
sentence_0.SerializeToString(&sentence_0_str);
Sentence long_sentence;
TextFormat::ParseFromString(kLongSentence, &long_sentence);
string long_sentence_str;
long_sentence.SerializeToString(&long_sentence_str);
auto test_parser = CreateParser({}, {sentence_0_str, long_sentence_str});
constexpr int kNumTokensInSentence = 3;
constexpr int kNumTokensInLongSentence = 5;
// There are 93 possible transitions for any given state. Create a transition
// array with a score of 10.0 for each transition.
constexpr int kBatchSize = 2;
constexpr int kBeamSize = 2;
constexpr int kNumPossibleTransitions = 93;
constexpr float kTransitionValue = 10.0;
float transition_matrix[kNumPossibleTransitions * kBeamSize * kBatchSize];
for (int i = 0; i < kNumPossibleTransitions * kBeamSize * kBatchSize; ++i) {
transition_matrix[i] = kTransitionValue;
}
// Transition the expected number of times.
constexpr int kExpectedNumTransitions = kNumTokensInLongSentence * 2;
for (int i = 0; i < kExpectedNumTransitions; ++i) {
EXPECT_FALSE(test_parser->IsTerminal());
test_parser->AdvanceFromPrediction(
transition_matrix, kNumPossibleTransitions * kBeamSize * kBatchSize);
}
// At this point, the test parser should be terminal.
EXPECT_TRUE(test_parser->IsTerminal());
// Check that the component is reporting 2N steps taken.
EXPECT_EQ(test_parser->StepsTaken(0), kNumTokensInSentence * 2);
EXPECT_EQ(test_parser->StepsTaken(1), kNumTokensInLongSentence * 2);
// The final states should have kExpectedNumTransitions * kTransitionValue.
auto beam = test_parser->GetBeam();
// The first sentence is shorter, so it should have a lower final score.
EXPECT_EQ(beam.at(0).at(0)->GetScore(),
kTransitionValue * kNumTokensInSentence * 2);
EXPECT_EQ(beam.at(1).at(0)->GetScore(),
kTransitionValue * kNumTokensInLongSentence * 2);
// Make sure the parser doesn't segfault.
test_parser->FinalizeData();
// TODO(googleuser): What should the finalized data look like?
}
TEST_F(SyntaxNetComponentTest, ResetAllowsReductionInBatchSize) {
// Create an empty input batch and beam vector to initialize the parser.
Sentence sentence_0;
TextFormat::ParseFromString(kSentence0, &sentence_0);
string sentence_0_str;
sentence_0.SerializeToString(&sentence_0_str);
Sentence long_sentence;
TextFormat::ParseFromString(kLongSentence, &long_sentence);
string long_sentence_str;
long_sentence.SerializeToString(&long_sentence_str);
// Get the master spec proto from the test data directory.
MasterSpec master_spec;
string file_name = tensorflow::io::JoinPath(
test::GetTestDataPrefix(), "dragnn/components/syntaxnet/testdata",
"master_spec.textproto");
TF_CHECK_OK(tensorflow::ReadTextProto(tensorflow::Env::Default(), file_name,
&master_spec));
// Get all the resource protos from the test data directory.
for (Resource &resource :
*(master_spec.mutable_component(0)->mutable_resource())) {
resource.mutable_part(0)->set_file_pattern(tensorflow::io::JoinPath(
test::GetTestDataPrefix(), "dragnn/components/syntaxnet/testdata",
resource.part(0).file_pattern()));
}
// Create an input batch cache with a large batch size.
constexpr int kBeamSize = 2;
std::unique_ptr<InputBatchCache> large_batch_data(new InputBatchCache(
{sentence_0_str, sentence_0_str, sentence_0_str, sentence_0_str}));
std::unique_ptr<SyntaxNetComponent> parser_component(
new SyntaxNetComponent());
parser_component->InitializeComponent(*(master_spec.mutable_component(0)));
parser_component->InitializeData({}, kBeamSize, large_batch_data.get());
// Reset the component and pass in a new input batch that is smaller.
parser_component->ResetComponent();
std::unique_ptr<InputBatchCache> small_batch_data(new InputBatchCache(
{long_sentence_str, long_sentence_str, long_sentence_str}));
parser_component->InitializeData({}, kBeamSize, small_batch_data.get());
// There are 93 possible transitions for any given state. Create a transition
// array with a score of 10.0 for each transition.
constexpr int kBatchSize = 3;
constexpr int kNumPossibleTransitions = 93;
constexpr float kTransitionValue = 10.0;
float transition_matrix[kNumPossibleTransitions * kBeamSize * kBatchSize];
for (int i = 0; i < kNumPossibleTransitions * kBeamSize * kBatchSize; ++i) {
transition_matrix[i] = kTransitionValue;
}
// Transition the expected number of times.
constexpr int kNumTokensInSentence = 5;
constexpr int kExpectedNumTransitions = kNumTokensInSentence * 2;
for (int i = 0; i < kExpectedNumTransitions; ++i) {
EXPECT_FALSE(parser_component->IsTerminal());
parser_component->AdvanceFromPrediction(
transition_matrix, kNumPossibleTransitions * kBeamSize * kBatchSize);
}
// At this point, the test parser should be terminal.
EXPECT_TRUE(parser_component->IsTerminal());
// Check that the component is reporting 2N steps taken.
EXPECT_EQ(parser_component->StepsTaken(0), kExpectedNumTransitions);
EXPECT_EQ(parser_component->StepsTaken(1), kExpectedNumTransitions);
EXPECT_EQ(parser_component->StepsTaken(2), kExpectedNumTransitions);
// The final states should have kExpectedNumTransitions * kTransitionValue.
auto beam = parser_component->GetBeam();
  // The returned beam vector should have one entry per batch element (3).
EXPECT_EQ(beam.size(), 3);
EXPECT_EQ(beam.at(0).at(0)->GetScore(),
kTransitionValue * kExpectedNumTransitions);
EXPECT_EQ(beam.at(1).at(0)->GetScore(),
kTransitionValue * kExpectedNumTransitions);
EXPECT_EQ(beam.at(2).at(0)->GetScore(),
kTransitionValue * kExpectedNumTransitions);
// Make sure the parser doesn't segfault.
parser_component->FinalizeData();
}
TEST_F(SyntaxNetComponentTest, ResetAllowsIncreaseInBatchSize) {
// Create an empty input batch and beam vector to initialize the parser.
Sentence sentence_0;
TextFormat::ParseFromString(kSentence0, &sentence_0);
string sentence_0_str;
sentence_0.SerializeToString(&sentence_0_str);
Sentence long_sentence;
TextFormat::ParseFromString(kLongSentence, &long_sentence);
string long_sentence_str;
long_sentence.SerializeToString(&long_sentence_str);
// Get the master spec proto from the test data directory.
MasterSpec master_spec;
string file_name = tensorflow::io::JoinPath(
test::GetTestDataPrefix(), "dragnn/components/syntaxnet/testdata",
"master_spec.textproto");
TF_CHECK_OK(tensorflow::ReadTextProto(tensorflow::Env::Default(), file_name,
&master_spec));
// Get all the resource protos from the test data directory.
for (Resource &resource :
*(master_spec.mutable_component(0)->mutable_resource())) {
resource.mutable_part(0)->set_file_pattern(tensorflow::io::JoinPath(
test::GetTestDataPrefix(), "dragnn/components/syntaxnet/testdata",
resource.part(0).file_pattern()));
}
// Create an input batch cache with a small batch size.
constexpr int kBeamSize = 2;
std::unique_ptr<InputBatchCache> small_batch_data(
new InputBatchCache(sentence_0_str));
std::unique_ptr<SyntaxNetComponent> parser_component(
new SyntaxNetComponent());
parser_component->InitializeComponent(*(master_spec.mutable_component(0)));
parser_component->InitializeData({}, kBeamSize, small_batch_data.get());
// Reset the component and pass in a new input batch that is larger.
parser_component->ResetComponent();
std::unique_ptr<InputBatchCache> large_batch_data(new InputBatchCache(
{long_sentence_str, long_sentence_str, long_sentence_str}));
parser_component->InitializeData({}, kBeamSize, large_batch_data.get());
// There are 93 possible transitions for any given state. Create a transition
// array with a score of 10.0 for each transition.
constexpr int kBatchSize = 3;
constexpr int kNumPossibleTransitions = 93;
constexpr float kTransitionValue = 10.0;
float transition_matrix[kNumPossibleTransitions * kBeamSize * kBatchSize];
for (int i = 0; i < kNumPossibleTransitions * kBeamSize * kBatchSize; ++i) {
transition_matrix[i] = kTransitionValue;
}
// Transition the expected number of times.
constexpr int kNumTokensInSentence = 5;
constexpr int kExpectedNumTransitions = kNumTokensInSentence * 2;
for (int i = 0; i < kExpectedNumTransitions; ++i) {
EXPECT_FALSE(parser_component->IsTerminal());
parser_component->AdvanceFromPrediction(
transition_matrix, kNumPossibleTransitions * kBeamSize * kBatchSize);
}
// At this point, the test parser should be terminal.
EXPECT_TRUE(parser_component->IsTerminal());
// Check that the component is reporting 2N steps taken.
EXPECT_EQ(parser_component->StepsTaken(0), kExpectedNumTransitions);
EXPECT_EQ(parser_component->StepsTaken(1), kExpectedNumTransitions);
EXPECT_EQ(parser_component->StepsTaken(2), kExpectedNumTransitions);
// The final states should have kExpectedNumTransitions * kTransitionValue.
auto beam = parser_component->GetBeam();
  // The returned beam vector should have one entry per batch element (3).
EXPECT_EQ(beam.size(), 3);
EXPECT_EQ(beam.at(0).at(0)->GetScore(),
kTransitionValue * kExpectedNumTransitions);
EXPECT_EQ(beam.at(1).at(0)->GetScore(),
kTransitionValue * kExpectedNumTransitions);
EXPECT_EQ(beam.at(2).at(0)->GetScore(),
kTransitionValue * kExpectedNumTransitions);
// Make sure the parser doesn't segfault.
parser_component->FinalizeData();
}
TEST_F(SyntaxNetComponentTest, ResetCausesBeamToReset) {
// Create an empty input batch and beam vector to initialize the parser.
Sentence sentence_0;
TextFormat::ParseFromString(kSentence0, &sentence_0);
string sentence_0_str;
sentence_0.SerializeToString(&sentence_0_str);
Sentence long_sentence;
TextFormat::ParseFromString(kLongSentence, &long_sentence);
string long_sentence_str;
long_sentence.SerializeToString(&long_sentence_str);
auto test_parser = CreateParser({}, {sentence_0_str});
constexpr int kNumTokensInSentence = 3;
// The master spec will initialize a parser, so expect 2*N transitions.
constexpr int kExpectedNumTransitions = kNumTokensInSentence * 2;
// There are 93 possible transitions for any given state. Create a transition
// array with a score of 10.0 for each transition.
constexpr int kBeamSize = 2;
constexpr int kNumPossibleTransitions = 93;
constexpr float kTransitionValue = 10.0;
float transition_matrix[kNumPossibleTransitions * kBeamSize];
for (int i = 0; i < kNumPossibleTransitions * kBeamSize; ++i) {
transition_matrix[i] = kTransitionValue;
}
// Transition the expected number of times.
for (int i = 0; i < kExpectedNumTransitions; ++i) {
EXPECT_FALSE(test_parser->IsTerminal());
test_parser->AdvanceFromPrediction(transition_matrix,
kNumPossibleTransitions * kBeamSize);
}
// At this point, the test parser should be terminal.
EXPECT_TRUE(test_parser->IsTerminal());
// Check that the component is reporting 2N steps taken.
EXPECT_EQ(test_parser->StepsTaken(0), kExpectedNumTransitions);
// The final states should have kExpectedNumTransitions * kTransitionValue.
auto beam = test_parser->GetBeam();
EXPECT_EQ(beam.at(0).at(0)->GetScore(),
kTransitionValue * kExpectedNumTransitions);
// Reset the test parser and give it new data.
test_parser->ResetComponent();
std::unique_ptr<InputBatchCache> new_data(
new InputBatchCache(long_sentence_str));
test_parser->InitializeData({}, kBeamSize, new_data.get());
// Check that the component is not terminal.
EXPECT_FALSE(test_parser->IsTerminal());
// Check that the component is reporting 0 steps taken.
EXPECT_EQ(test_parser->StepsTaken(0), 0);
// The states should have 0 as their score.
auto new_beam = test_parser->GetBeam();
EXPECT_EQ(new_beam.at(0).at(0)->GetScore(), 0);
}
TEST_F(SyntaxNetComponentTest, AdjustingMaxBeamSizeAdjustsSizeForAllBeams) {
// Create an empty input batch and beam vector to initialize the parser.
Sentence sentence_0;
TextFormat::ParseFromString(kSentence0, &sentence_0);
string sentence_0_str;
sentence_0.SerializeToString(&sentence_0_str);
Sentence long_sentence;
TextFormat::ParseFromString(kLongSentence, &long_sentence);
string long_sentence_str;
long_sentence.SerializeToString(&long_sentence_str);
// Get the master spec proto from the test data directory.
MasterSpec master_spec;
string file_name = tensorflow::io::JoinPath(
test::GetTestDataPrefix(), "dragnn/components/syntaxnet/testdata",
"master_spec.textproto");
TF_CHECK_OK(tensorflow::ReadTextProto(tensorflow::Env::Default(), file_name,
&master_spec));
// Get all the resource protos from the test data directory.
for (Resource &resource :
*(master_spec.mutable_component(0)->mutable_resource())) {
resource.mutable_part(0)->set_file_pattern(tensorflow::io::JoinPath(
test::GetTestDataPrefix(), "dragnn/components/syntaxnet/testdata",
resource.part(0).file_pattern()));
}
// Create an input batch cache with a small batch size.
constexpr int kBeamSize = 2;
std::unique_ptr<InputBatchCache> small_batch_data(
new InputBatchCache(sentence_0_str));
std::unique_ptr<SyntaxNetComponent> parser_component(
new SyntaxNetComponent());
parser_component->InitializeComponent(*(master_spec.mutable_component(0)));
parser_component->InitializeData({}, kBeamSize, small_batch_data.get());
// Make sure all the beams in the batch have max size 2.
for (const auto &beam : GetBeams(parser_component.get())) {
EXPECT_EQ(beam->max_size(), kBeamSize);
}
// Reset the component and pass in a new input batch that is larger, with
// a higher beam size.
constexpr int kNewBeamSize = 5;
parser_component->ResetComponent();
std::unique_ptr<InputBatchCache> large_batch_data(new InputBatchCache(
{long_sentence_str, long_sentence_str, long_sentence_str}));
parser_component->InitializeData({}, kNewBeamSize, large_batch_data.get());
// Make sure all the beams in the batch now have max size 5.
for (const auto &beam : GetBeams(parser_component.get())) {
EXPECT_EQ(beam->max_size(), kNewBeamSize);
}
}
TEST_F(SyntaxNetComponentTest, SettingBeamSizeZeroFails) {
// Create an empty input batch and beam vector to initialize the parser.
Sentence sentence_0;
TextFormat::ParseFromString(kSentence0, &sentence_0);
string sentence_0_str;
sentence_0.SerializeToString(&sentence_0_str);
Sentence long_sentence;
TextFormat::ParseFromString(kLongSentence, &long_sentence);
string long_sentence_str;
long_sentence.SerializeToString(&long_sentence_str);
// Get the master spec proto from the test data directory.
MasterSpec master_spec;
string file_name = tensorflow::io::JoinPath(
test::GetTestDataPrefix(), "dragnn/components/syntaxnet/testdata",
"master_spec.textproto");
TF_CHECK_OK(tensorflow::ReadTextProto(tensorflow::Env::Default(), file_name,
&master_spec));
// Get all the resource protos from the test data directory.
for (Resource &resource :
*(master_spec.mutable_component(0)->mutable_resource())) {
resource.mutable_part(0)->set_file_pattern(tensorflow::io::JoinPath(
test::GetTestDataPrefix(), "dragnn/components/syntaxnet/testdata",
resource.part(0).file_pattern()));
}
// Create an input batch cache with a small batch size.
constexpr int kBeamSize = 0;
std::unique_ptr<InputBatchCache> small_batch_data(
new InputBatchCache(sentence_0_str));
std::unique_ptr<SyntaxNetComponent> parser_component(
new SyntaxNetComponent());
parser_component->InitializeComponent(*(master_spec.mutable_component(0)));
EXPECT_DEATH(
parser_component->InitializeData({}, kBeamSize, small_batch_data.get()),
"must be greater than 0");
}
TEST_F(SyntaxNetComponentTest, ExportsFixedFeaturesWithPadding) {
// Create an empty input batch and beam vector to initialize the parser.
Sentence sentence_0;
TextFormat::ParseFromString(kSentence0, &sentence_0);
string sentence_0_str;
sentence_0.SerializeToString(&sentence_0_str);
Sentence sentence_1;
TextFormat::ParseFromString(kSentence1, &sentence_1);
string sentence_1_str;
sentence_1.SerializeToString(&sentence_1_str);
constexpr int kBeamSize = 3;
auto test_parser =
CreateParserWithBeamSize(kBeamSize, {}, {sentence_0_str, sentence_1_str});
// Get and check the raw link features.
vector<int32> indices;
auto indices_fn = [&indices](int size) {
indices.resize(size);
return indices.data();
};
vector<int64> ids;
auto ids_fn = [&ids](int size) {
ids.resize(size);
return ids.data();
};
vector<float> weights;
auto weights_fn = [&weights](int size) {
weights.resize(size);
return weights.data();
};
constexpr int kChannelId = 0;
const int num_features =
test_parser->GetFixedFeatures(indices_fn, ids_fn, weights_fn, kChannelId);
  // Each beam element produces two fixed features for this channel. Because
  // each beam currently holds only one element (two elements total across the
  // batch of two), the remaining slots are padding, so only feature slots
  // 0, 1 and 6, 7 are populated, each with a single id.
constexpr int kExpectedOutputSize = 4;
const vector<int32> expected_indices({0, 1, 6, 7});
const vector<int64> expected_ids({0, 12, 0, 12});
const vector<float> expected_weights({1.0, 1.0, 1.0, 1.0});
EXPECT_EQ(expected_indices.size(), kExpectedOutputSize);
EXPECT_EQ(expected_ids.size(), kExpectedOutputSize);
EXPECT_EQ(expected_weights.size(), kExpectedOutputSize);
EXPECT_EQ(num_features, kExpectedOutputSize);
EXPECT_EQ(expected_indices, indices);
EXPECT_EQ(expected_ids, ids);
EXPECT_EQ(expected_weights, weights);
}
TEST_F(SyntaxNetComponentTest, ExportsFixedFeatures) {
// Create an empty input batch and beam vector to initialize the parser.
Sentence sentence_0;
TextFormat::ParseFromString(kSentence0, &sentence_0);
string sentence_0_str;
sentence_0.SerializeToString(&sentence_0_str);
Sentence sentence_1;
TextFormat::ParseFromString(kSentence1, &sentence_1);
string sentence_1_str;
sentence_1.SerializeToString(&sentence_1_str);
constexpr int kBeamSize = 3;
auto test_parser =
CreateParserWithBeamSize(kBeamSize, {}, {sentence_0_str, sentence_1_str});
// There are 93 possible transitions for any given state. Create a transition
// array with a score of 10.0 for each transition.
constexpr int kBatchSize = 2;
constexpr int kNumPossibleTransitions = 93;
constexpr float kTransitionValue = 10.0;
float transition_matrix[kNumPossibleTransitions * kBeamSize * kBatchSize];
for (int i = 0; i < kNumPossibleTransitions * kBeamSize * kBatchSize; ++i) {
transition_matrix[i] = kTransitionValue;
}
// Advance twice, so that the underlying parser fills the beam.
test_parser->AdvanceFromPrediction(
transition_matrix, kNumPossibleTransitions * kBeamSize * kBatchSize);
test_parser->AdvanceFromPrediction(
transition_matrix, kNumPossibleTransitions * kBeamSize * kBatchSize);
  // Get and check the fixed features.
vector<int32> indices;
auto indices_fn = [&indices](int size) {
indices.resize(size);
return indices.data();
};
vector<int64> ids;
auto ids_fn = [&ids](int size) {
ids.resize(size);
return ids.data();
};
vector<float> weights;
auto weights_fn = [&weights](int size) {
weights.resize(size);
return weights.data();
};
constexpr int kChannelId = 0;
const int num_features =
test_parser->GetFixedFeatures(indices_fn, ids_fn, weights_fn, kChannelId);
// In this case, all even features and all odd features are identical.
constexpr int kExpectedOutputSize = 12;
const vector<int32> expected_indices({0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11});
const vector<int64> expected_ids({12, 7, 12, 7, 12, 7, 12, 7, 12, 7, 12, 7});
const vector<float> expected_weights(
{1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0});
EXPECT_EQ(expected_indices.size(), kExpectedOutputSize);
EXPECT_EQ(expected_ids.size(), kExpectedOutputSize);
EXPECT_EQ(expected_weights.size(), kExpectedOutputSize);
EXPECT_EQ(num_features, kExpectedOutputSize);
EXPECT_EQ(expected_indices, indices);
EXPECT_EQ(expected_ids, ids);
EXPECT_EQ(expected_weights, weights);
}
TEST_F(SyntaxNetComponentTest, ExportsBulkFixedFeatures) {
// Create an empty input batch and beam vector to initialize the parser.
Sentence sentence_0;
TextFormat::ParseFromString(kSentence0, &sentence_0);
string sentence_0_str;
sentence_0.SerializeToString(&sentence_0_str);
Sentence sentence_1;
TextFormat::ParseFromString(kSentence1, &sentence_1);
string sentence_1_str;
sentence_1.SerializeToString(&sentence_1_str);
constexpr int kBeamSize = 3;
auto test_parser =
CreateParserWithBeamSize(kBeamSize, {}, {sentence_0_str, sentence_1_str});
  // Get and check the bulk fixed features.
vector<vector<int32>> indices;
auto indices_fn = [&indices](int channel, int size) {
indices.resize(channel + 1);
indices[channel].resize(size);
return indices[channel].data();
};
vector<vector<int64>> ids;
auto ids_fn = [&ids](int channel, int size) {
ids.resize(channel + 1);
ids[channel].resize(size);
return ids[channel].data();
};
vector<vector<float>> weights;
auto weights_fn = [&weights](int channel, int size) {
weights.resize(channel + 1);
weights[channel].resize(size);
return weights[channel].data();
};
BulkFeatureExtractor extractor(indices_fn, ids_fn, weights_fn);
const int num_steps = test_parser->BulkGetFixedFeatures(extractor);
  // There should be 6 steps (2 * N, where N is the token count of the longest
  // sentence in the batch).
EXPECT_EQ(num_steps, 6);
// These are empirically derived.
const vector<int32> expected_ch0_indices({0, 36, 18, 54, 1, 37, 19, 55,
2, 38, 20, 56, 3, 39, 21, 57,
4, 40, 22, 58, 5, 41, 23, 59});
const vector<int64> expected_ch0_ids({0, 12, 0, 12, 12, 7, 12, 7,
7, 50, 7, 50, 7, 50, 7, 50,
50, 50, 50, 50, 50, 50, 50, 50});
const vector<float> expected_ch0_weights(
{1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1});
const vector<int32> expected_ch1_indices(
{0, 36, 72, 18, 54, 90, 1, 37, 73, 19, 55, 91, 2, 38, 74, 20, 56, 92,
3, 39, 75, 21, 57, 93, 4, 40, 76, 22, 58, 94, 5, 41, 77, 23, 59, 95});
const vector<int64> expected_ch1_ids(
{51, 0, 12, 51, 0, 12, 0, 12, 7, 0, 12, 7, 12, 7, 50, 12, 7, 50,
12, 7, 50, 12, 7, 50, 7, 50, 50, 7, 50, 50, 7, 50, 50, 7, 50, 50});
const vector<float> expected_ch1_weights(
{1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1});
EXPECT_EQ(indices[0], expected_ch0_indices);
EXPECT_EQ(ids[0], expected_ch0_ids);
EXPECT_EQ(weights[0], expected_ch0_weights);
EXPECT_EQ(indices[1], expected_ch1_indices);
EXPECT_EQ(ids[1], expected_ch1_ids);
EXPECT_EQ(weights[1], expected_ch1_weights);
}
TEST_F(SyntaxNetComponentTest, ExportsRawLinkFeaturesWithPadding) {
// Create an empty input batch and beam vector to initialize the parser.
Sentence sentence_0;
TextFormat::ParseFromString(kSentence0, &sentence_0);
string sentence_0_str;
sentence_0.SerializeToString(&sentence_0_str);
Sentence sentence_1;
TextFormat::ParseFromString(kSentence1, &sentence_1);
string sentence_1_str;
sentence_1.SerializeToString(&sentence_1_str);
constexpr int kBeamSize = 3;
constexpr int kBatchSize = 2;
auto test_parser =
CreateParserWithBeamSize(kBeamSize, {}, {sentence_0_str, sentence_1_str});
// Get and check the raw link features.
constexpr int kNumLinkFeatures = 2;
auto link_features = test_parser->GetRawLinkFeatures(0);
EXPECT_EQ(link_features.size(), kBeamSize * kBatchSize * kNumLinkFeatures);
EXPECT_EQ(link_features.at(0).feature_value(), -1);
EXPECT_EQ(link_features.at(0).batch_idx(), 0);
EXPECT_EQ(link_features.at(0).beam_idx(), 0);
EXPECT_EQ(link_features.at(1).feature_value(), -2);
EXPECT_EQ(link_features.at(1).batch_idx(), 0);
EXPECT_EQ(link_features.at(1).beam_idx(), 0);
// These are padding, so we do not expect them to have a feature value.
EXPECT_FALSE(link_features.at(2).has_feature_value());
EXPECT_FALSE(link_features.at(2).has_batch_idx());
EXPECT_FALSE(link_features.at(2).has_beam_idx());
EXPECT_FALSE(link_features.at(3).has_feature_value());
EXPECT_FALSE(link_features.at(3).has_batch_idx());
EXPECT_FALSE(link_features.at(3).has_beam_idx());
EXPECT_FALSE(link_features.at(4).has_feature_value());
EXPECT_FALSE(link_features.at(4).has_batch_idx());
EXPECT_FALSE(link_features.at(4).has_beam_idx());
EXPECT_FALSE(link_features.at(5).has_feature_value());
EXPECT_FALSE(link_features.at(5).has_batch_idx());
EXPECT_FALSE(link_features.at(5).has_beam_idx());
EXPECT_EQ(link_features.at(6).feature_value(), -1);
EXPECT_EQ(link_features.at(6).batch_idx(), 1);
EXPECT_EQ(link_features.at(6).beam_idx(), 0);
EXPECT_EQ(link_features.at(7).feature_value(), -2);
EXPECT_EQ(link_features.at(7).batch_idx(), 1);
EXPECT_EQ(link_features.at(7).beam_idx(), 0);
// These are padding, so we do not expect them to have a feature value.
EXPECT_FALSE(link_features.at(8).has_feature_value());
EXPECT_FALSE(link_features.at(8).has_batch_idx());
EXPECT_FALSE(link_features.at(8).has_beam_idx());
EXPECT_FALSE(link_features.at(9).has_feature_value());
EXPECT_FALSE(link_features.at(9).has_batch_idx());
EXPECT_FALSE(link_features.at(9).has_beam_idx());
EXPECT_FALSE(link_features.at(10).has_feature_value());
EXPECT_FALSE(link_features.at(10).has_batch_idx());
EXPECT_FALSE(link_features.at(10).has_beam_idx());
EXPECT_FALSE(link_features.at(11).has_feature_value());
EXPECT_FALSE(link_features.at(11).has_batch_idx());
EXPECT_FALSE(link_features.at(11).has_beam_idx());
}
TEST_F(SyntaxNetComponentTest, ExportsRawLinkFeatures) {
// Create an empty input batch and beam vector to initialize the parser.
Sentence sentence_0;
TextFormat::ParseFromString(kSentence0, &sentence_0);
string sentence_0_str;
sentence_0.SerializeToString(&sentence_0_str);
Sentence sentence_1;
TextFormat::ParseFromString(kSentence1, &sentence_1);
string sentence_1_str;
sentence_1.SerializeToString(&sentence_1_str);
constexpr int kBeamSize = 3;
auto test_parser =
CreateParserWithBeamSize(kBeamSize, {}, {sentence_0_str, sentence_1_str});
// There are 93 possible transitions for any given state. Create a transition
// array with a score of 10.0 for each transition.
constexpr int kBatchSize = 2;
constexpr int kNumPossibleTransitions = 93;
constexpr float kTransitionValue = 10.0;
float transition_matrix[kNumPossibleTransitions * kBeamSize * kBatchSize];
for (int i = 0; i < kNumPossibleTransitions * kBeamSize * kBatchSize; ++i) {
transition_matrix[i] = kTransitionValue;
}
// Advance twice, so that the underlying parser fills the beam.
test_parser->AdvanceFromPrediction(
transition_matrix, kNumPossibleTransitions * kBeamSize * kBatchSize);
test_parser->AdvanceFromPrediction(
transition_matrix, kNumPossibleTransitions * kBeamSize * kBatchSize);
// Get and check the raw link features.
constexpr int kNumLinkFeatures = 2;
auto link_features = test_parser->GetRawLinkFeatures(0);
EXPECT_EQ(link_features.size(), kBeamSize * kBatchSize * kNumLinkFeatures);
// These should index into batch 0.
EXPECT_EQ(link_features.at(0).feature_value(), -1);
EXPECT_EQ(link_features.at(0).batch_idx(), 0);
EXPECT_EQ(link_features.at(0).beam_idx(), 0);
EXPECT_EQ(link_features.at(1).feature_value(), -2);
EXPECT_EQ(link_features.at(1).batch_idx(), 0);
EXPECT_EQ(link_features.at(1).beam_idx(), 0);
EXPECT_EQ(link_features.at(2).feature_value(), -1);
EXPECT_EQ(link_features.at(2).batch_idx(), 0);
EXPECT_EQ(link_features.at(2).beam_idx(), 1);
EXPECT_EQ(link_features.at(3).feature_value(), -2);
EXPECT_EQ(link_features.at(3).batch_idx(), 0);
EXPECT_EQ(link_features.at(3).beam_idx(), 1);
EXPECT_EQ(link_features.at(4).feature_value(), -1);
EXPECT_EQ(link_features.at(4).batch_idx(), 0);
EXPECT_EQ(link_features.at(4).beam_idx(), 2);
EXPECT_EQ(link_features.at(5).feature_value(), -2);
EXPECT_EQ(link_features.at(5).batch_idx(), 0);
EXPECT_EQ(link_features.at(5).beam_idx(), 2);
// These should index into batch 1.
EXPECT_EQ(link_features.at(6).feature_value(), -1);
EXPECT_EQ(link_features.at(6).batch_idx(), 1);
EXPECT_EQ(link_features.at(6).beam_idx(), 0);
EXPECT_EQ(link_features.at(7).feature_value(), -2);
EXPECT_EQ(link_features.at(7).batch_idx(), 1);
EXPECT_EQ(link_features.at(7).beam_idx(), 0);
EXPECT_EQ(link_features.at(8).feature_value(), -1);
EXPECT_EQ(link_features.at(8).batch_idx(), 1);
EXPECT_EQ(link_features.at(8).beam_idx(), 1);
EXPECT_EQ(link_features.at(9).feature_value(), -2);
EXPECT_EQ(link_features.at(9).batch_idx(), 1);
EXPECT_EQ(link_features.at(9).beam_idx(), 1);
EXPECT_EQ(link_features.at(10).feature_value(), -1);
EXPECT_EQ(link_features.at(10).batch_idx(), 1);
EXPECT_EQ(link_features.at(10).beam_idx(), 2);
EXPECT_EQ(link_features.at(11).feature_value(), -2);
EXPECT_EQ(link_features.at(11).batch_idx(), 1);
EXPECT_EQ(link_features.at(11).beam_idx(), 2);
}
TEST_F(SyntaxNetComponentTest, AdvancesFromOracleWithTracing) {
// Create an empty input batch and beam vector to initialize the parser.
Sentence sentence_0;
TextFormat::ParseFromString(kSentence0, &sentence_0);
string sentence_0_str;
sentence_0.SerializeToString(&sentence_0_str);
constexpr int kBeamSize = 1;
auto test_parser = CreateParserWithBeamSize(kBeamSize, {}, {sentence_0_str});
test_parser->InitializeTracing();
constexpr int kNumTokensInSentence = 3;
// The master spec will initialize a parser, so expect 2*N transitions.
constexpr int kExpectedNumTransitions = kNumTokensInSentence * 2;
constexpr int kFixedFeatureChannels = 1;
for (int i = 0; i < kExpectedNumTransitions; ++i) {
EXPECT_FALSE(test_parser->IsTerminal());
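// The following callbacks act as allocators for GetFixedFeatures(): each one
// resizes its backing vector to the requested size and returns a pointer to
// that storage for the component to fill in.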
vector<int32> indices;
auto indices_fn = [&indices](int size) {
indices.resize(size);
return indices.data();
};
vector<int64> ids;
auto ids_fn = [&ids](int size) {
ids.resize(size);
return ids.data();
};
vector<float> weights;
auto weights_fn = [&weights](int size) {
weights.resize(size);
return weights.data();
};
for (int j = 0; j < kFixedFeatureChannels; ++j) {
test_parser->GetFixedFeatures(indices_fn, ids_fn, weights_fn, j);
}
auto features = test_parser->GetRawLinkFeatures(0);
// Make some fake translations to test visualization.
for (int j = 0; j < features.size(); ++j) {
features[j].set_step_idx(j < i ? j : -1);
}
test_parser->AddTranslatedLinkFeaturesToTrace(features, 0);
test_parser->AdvanceFromOracle();
}
// At this point, the test parser should be terminal.
EXPECT_TRUE(test_parser->IsTerminal());
// TODO(googleuser): Add EXPECT_EQ here instead of printing.
std::vector<std::vector<ComponentTrace>> traces =
test_parser->GetTraceProtos();
for (auto &batch_trace : traces) {
for (auto &trace : batch_trace) {
LOG(INFO) << "trace:" << std::endl << trace.DebugString();
}
}
}
TEST_F(SyntaxNetComponentTest, NoTracingDropsFeatureNames) {
// Create an empty input batch and beam vector to initialize the parser.
Sentence sentence_0;
TextFormat::ParseFromString(kSentence0, &sentence_0);
string sentence_0_str;
sentence_0.SerializeToString(&sentence_0_str);
constexpr int kBeamSize = 1;
const auto test_parser =
CreateParserWithBeamSize(kBeamSize, {}, {sentence_0_str});
const auto link_features = test_parser->GetRawLinkFeatures(0);
// The fml associated with the channel is "stack.focus stack(1).focus".
// Both features should lack the feature_name field.
EXPECT_EQ(link_features.size(), 2);
EXPECT_FALSE(link_features.at(0).has_feature_name());
EXPECT_FALSE(link_features.at(1).has_feature_name());
}
TEST_F(SyntaxNetComponentTest, TracingOutputsFeatureNames) {
// Create an empty input batch and beam vector to initialize the parser.
Sentence sentence_0;
TextFormat::ParseFromString(kSentence0, &sentence_0);
string sentence_0_str;
sentence_0.SerializeToString(&sentence_0_str);
constexpr int kBeamSize = 1;
auto test_parser = CreateParserWithBeamSize(kBeamSize, {}, {sentence_0_str});
test_parser->InitializeTracing();
const auto link_features = test_parser->GetRawLinkFeatures(0);
// The fml associated with the channel is "stack.focus stack(1).focus".
EXPECT_EQ(link_features.size(), 2);
EXPECT_EQ(link_features.at(0).feature_name(), "stack.focus");
EXPECT_EQ(link_features.at(1).feature_name(), "stack(1).focus");
}
} // namespace dragnn
} // namespace syntaxnet
#include "dragnn/components/syntaxnet/syntaxnet_link_feature_extractor.h"
#include "tensorflow/core/platform/logging.h"
namespace syntaxnet {
namespace dragnn {
void SyntaxNetLinkFeatureExtractor::Setup(TaskContext *context) {
ParserEmbeddingFeatureExtractor::Setup(context);
if (NumEmbeddings() > 0) {
channel_sources_ = utils::Split(
context->Get(
tensorflow::strings::StrCat(ArgPrefix(), "_", "source_components"),
""),
';');
channel_layers_ = utils::Split(
context->Get(
tensorflow::strings::StrCat(ArgPrefix(), "_", "source_layers"), ""),
';');
channel_translators_ = utils::Split(
context->Get(
tensorflow::strings::StrCat(ArgPrefix(), "_", "source_translators"),
""),
';');
}
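// Every linked embedding channel must name a source component, a source
// layer, and a source translator.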
CHECK_EQ(channel_sources_.size(), NumEmbeddings());
CHECK_EQ(channel_layers_.size(), NumEmbeddings());
CHECK_EQ(channel_translators_.size(), NumEmbeddings());
}
void SyntaxNetLinkFeatureExtractor::AddLinkedFeatureChannelProtos(
ComponentSpec *spec) const {
for (int embedding_idx = 0; embedding_idx < NumEmbeddings();
++embedding_idx) {
LinkedFeatureChannel *channel = spec->add_linked_feature();
channel->set_name(embedding_name(embedding_idx));
channel->set_fml(embedding_fml()[embedding_idx]);
channel->set_embedding_dim(EmbeddingDims(embedding_idx));
channel->set_size(FeatureSize(embedding_idx));
channel->set_source_layer(channel_layers_[embedding_idx]);
channel->set_source_component(channel_sources_[embedding_idx]);
channel->set_source_translator(channel_translators_[embedding_idx]);
}
}
} // namespace dragnn
} // namespace syntaxnet
#ifndef NLP_SAFT_OPENSOURCE_DRAGNN_COMPONENTS_SYNTAXNET_SYNTAXNET_LINK_FEATURE_EXTRACTOR_H_
#define NLP_SAFT_OPENSOURCE_DRAGNN_COMPONENTS_SYNTAXNET_SYNTAXNET_LINK_FEATURE_EXTRACTOR_H_
#include <string>
#include <vector>
#include "dragnn/protos/spec.pb.h"
#include "syntaxnet/embedding_feature_extractor.h"
#include "syntaxnet/parser_state.h"
#include "syntaxnet/parser_transitions.h"
#include "syntaxnet/task_context.h"
namespace syntaxnet {
namespace dragnn {
// Provides feature extraction for linked features in the
// WrapperParserComponent. It re-uses the EmbeddingFeatureExtractor
// architecture to obtain another set of feature extractors. Predicate maps
// are ignored here, and the vocabulary size does not matter because all
// feature values are used only for translation; as a result, the extractor
// can be configured from the GCL using the standard neurosis-lib.wf syntax.
//
// Because it uses a different prefix, it can be executed in the same wf.stage
// as the regular fixed extractor.
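//
// As an illustrative example (the values below mirror those used in
// ExportSpecTest::WritesChannelSpec), Setup() reads TaskContext parameters
// with one semicolon-separated entry per linked channel:
//   link_features: "input.focus;stack.focus"
//   link_embedding_names: "tagger;parser"
//   link_embedding_dims: "16;16"
//   link_source_components: "tagger;parser"
//   link_source_layers: "hidden0;lstm"
//   link_source_translators: "token;last_action"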
class SyntaxNetLinkFeatureExtractor : public ParserEmbeddingFeatureExtractor {
public:
SyntaxNetLinkFeatureExtractor() : ParserEmbeddingFeatureExtractor("link") {}
~SyntaxNetLinkFeatureExtractor() override {}
const string ArgPrefix() const override { return "link"; }
// Parses the TaskContext to get additional information like target layers,
// etc.
void Setup(TaskContext *context) override;
// Called during InitComponentProtoTask to add the specification from the
// wrapped feature extractor as LinkedFeatureChannel protos.
void AddLinkedFeatureChannelProtos(ComponentSpec *spec) const;
private:
// Source component names for each channel.
std::vector<string> channel_sources_;
// Source layer names for each channel.
std::vector<string> channel_layers_;
// Source translator name for each channel.
std::vector<string> channel_translators_;
};
} // namespace dragnn
} // namespace syntaxnet
#endif // NLP_SAFT_OPENSOURCE_DRAGNN_COMPONENTS_SYNTAXNET_SYNTAXNET_LINK_FEATURE_EXTRACTOR_H_
#include "dragnn/components/syntaxnet/syntaxnet_link_feature_extractor.h"
#include <string>
#include "dragnn/core/test/generic.h"
#include "dragnn/protos/spec.pb.h"
#include "syntaxnet/task_context.h"
#include "tensorflow/core/platform/test.h"
using syntaxnet::test::EqualsProto;
namespace syntaxnet {
namespace dragnn {
class ExportSpecTest : public ::testing::Test {
public:
};
TEST_F(ExportSpecTest, WritesChannelSpec) {
TaskContext context;
context.SetParameter("neurosis_feature_syntax_version", "2");
context.SetParameter("link_features", "input.focus;stack.focus");
context.SetParameter("link_embedding_names", "tagger;parser");
context.SetParameter("link_predicate_maps", "none;none");
context.SetParameter("link_embedding_dims", "16;16");
context.SetParameter("link_source_components", "tagger;parser");
context.SetParameter("link_source_layers", "hidden0;lstm");
context.SetParameter("link_source_translators", "token;last_action");
SyntaxNetLinkFeatureExtractor link_features;
link_features.Setup(&context);
link_features.Init(&context);
ComponentSpec spec;
link_features.AddLinkedFeatureChannelProtos(&spec);
const string expected_spec_str = R"(
linked_feature {
name: "tagger"
fml: "input.focus"
embedding_dim: 16
size: 1
source_component: "tagger"
source_translator: "token"
source_layer: "hidden0"
}
linked_feature {
name: "parser"
fml: "stack.focus"
embedding_dim: 16
size: 1
source_component: "parser"
source_translator: "last_action"
source_layer: "lstm"
}
)";
ComponentSpec expected_spec;
TextFormat::ParseFromString(expected_spec_str, &expected_spec);
EXPECT_THAT(spec, EqualsProto(expected_spec));
}
} // namespace dragnn
} // namespace syntaxnet
#include "dragnn/components/syntaxnet/syntaxnet_transition_state.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/logging.h"
namespace syntaxnet {
namespace dragnn {
SyntaxNetTransitionState::SyntaxNetTransitionState(
std::unique_ptr<ParserState> parser_state, SyntaxNetSentence *sentence)
: parser_state_(std::move(parser_state)), sentence_(sentence) {
score_ = 0;
current_beam_index_ = -1;
parent_beam_index_ = 0;
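// Per-token maps default to -1 ("unset") for every token in the sentence.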
step_for_token_.resize(sentence->sentence()->token_size(), -1);
parent_for_token_.resize(sentence->sentence()->token_size(), -1);
parent_step_for_token_.resize(sentence->sentence()->token_size(), -1);
}
void SyntaxNetTransitionState::Init(const TransitionState &parent) {
score_ = parent.GetScore();
parent_beam_index_ = parent.GetBeamIndex();
}
std::unique_ptr<SyntaxNetTransitionState> SyntaxNetTransitionState::Clone()
const {
// Create a new state from a clone of the underlying parser state.
std::unique_ptr<ParserState> cloned_state(parser_state_->Clone());
std::unique_ptr<SyntaxNetTransitionState> new_state(
new SyntaxNetTransitionState(std::move(cloned_state), sentence_));
// Copy relevant data members and set non-copied ones to flag values.
new_state->score_ = score_;
new_state->current_beam_index_ = current_beam_index_;
new_state->parent_beam_index_ = parent_beam_index_;
new_state->step_for_token_ = step_for_token_;
new_state->parent_step_for_token_ = parent_step_for_token_;
new_state->parent_for_token_ = parent_for_token_;
// Copy trace if it exists.
if (trace_) {
new_state->trace_.reset(new ComponentTrace(*trace_));
}
return new_state;
}
const int SyntaxNetTransitionState::ParentBeamIndex() const {
return parent_beam_index_;
}
const int SyntaxNetTransitionState::GetBeamIndex() const {
return current_beam_index_;
}
void SyntaxNetTransitionState::SetBeamIndex(const int index) {
current_beam_index_ = index;
}
const float SyntaxNetTransitionState::GetScore() const { return score_; }
void SyntaxNetTransitionState::SetScore(const float score) { score_ = score; }
string SyntaxNetTransitionState::HTMLRepresentation() const {
// Crude HTML string showing the stack and the word on the input.
string html = "Stack: ";
for (int i = parser_state_->StackSize() - 1; i >= 0; --i) {
const int word_idx = parser_state_->Stack(i);
if (word_idx >= 0) {
tensorflow::strings::StrAppend(
&html, parser_state_->GetToken(word_idx).word(), " ");
}
}
tensorflow::strings::StrAppend(&html, "| Input: ");
const int word_idx = parser_state_->Input(0);
if (word_idx >= 0) {
tensorflow::strings::StrAppend(
&html, parser_state_->GetToken(word_idx).word(), " ");
}
return html;
}
} // namespace dragnn
} // namespace syntaxnet
#ifndef NLP_SAFT_OPENSOURCE_DRAGNN_COMPONENTS_SYNTAXNET_SYNTAXNET_TRANSITION_STATE_H_
#define NLP_SAFT_OPENSOURCE_DRAGNN_COMPONENTS_SYNTAXNET_SYNTAXNET_TRANSITION_STATE_H_
#include <vector>
#include "dragnn/core/interfaces/cloneable_transition_state.h"
#include "dragnn/core/interfaces/transition_state.h"
#include "dragnn/io/syntaxnet_sentence.h"
#include "dragnn/protos/trace.pb.h"
#include "syntaxnet/base.h"
#include "syntaxnet/parser_state.h"
namespace syntaxnet {
namespace dragnn {
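// Transition state used by the SyntaxNet DRAGNN component. It wraps an
// underlying nlp_saft::ParserState together with the SyntaxNetSentence being
// parsed, tracks beam bookkeeping (current and parent beam indices, score),
// and records per-token step and parent information used when translating
// linked features.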
class SyntaxNetTransitionState
: public CloneableTransitionState<SyntaxNetTransitionState> {
public:
// Create a SyntaxNetTransitionState to wrap this nlp_saft::ParserState.
SyntaxNetTransitionState(std::unique_ptr<ParserState> parser_state,
SyntaxNetSentence *sentence);
// Initialize this TransitionState from a previous TransitionState. The
// ParentBeamIndex is the location of that previous TransitionState in the
// provided beam.
void Init(const TransitionState &parent) override;
// Produces a new state with the same backing data as this state.
std::unique_ptr<SyntaxNetTransitionState> Clone() const override;
// Return the beam index of the state passed into the initializer of this
// TransitionState.
const int ParentBeamIndex() const override;
// Get the current beam index for this state.
const int GetBeamIndex() const override;
// Set the current beam index for this state.
void SetBeamIndex(const int index) override;
// Get the score associated with this transition state.
const float GetScore() const override;
// Set the score associated with this transition state.
void SetScore(const float score) override;
// Depicts this state as an HTML-language string.
string HTMLRepresentation() const override;
// **** END INHERITED INTERFACE ****
// TODO(googleuser): Make these comments actually mean something.
// Returns the step at which token |token| received its representation, or -1
// if the token is out of range or has not been assigned a step yet.
int step_for_token(int token) {
if (token < 0 || token >= step_for_token_.size()) {
return -1;
} else {
return step_for_token_.at(token);
}
}
// Records |step| as the step for token |token|; out-of-range tokens are
// ignored.
void set_step_for_token(int token, int step) {
if (token >= 0 && token < step_for_token_.size()) {
step_for_token_[token] = step;
}
}
// Returns the step at which token |token| was assigned its parent, or -1 if
// the token is out of range or has no parent assignment yet.
int parent_step_for_token(int token) {
if (token < 0 || token >= parent_step_for_token_.size()) {
return -1;
} else {
return parent_step_for_token_.at(token);
}
}
// Records |parent_step| as the parent-assignment step for token |token|;
// out-of-range tokens are ignored.
void set_parent_step_for_token(int token, int parent_step) {
if (token >= 0 && token < parent_step_for_token_.size()) {
parent_step_for_token_[token] = parent_step;
}
}
// Returns the parent (head) index assigned to token |token|, or -1 if the
// token is out of range or has no parent yet.
int parent_for_token(int token) {
if (token < 0 || token >= parent_for_token_.size()) {
return -1;
} else {
return parent_for_token_.at(token);
}
}
// Records |parent| as the parent (head) of token |token|; out-of-range
// tokens are ignored.
void set_parent_for_token(int token, int parent) {
if (token >= 0 && token < parent_for_token_.size()) {
parent_for_token_[token] = parent;
}
}
// Accessor for the underlying nlp_saft::ParserState.
ParserState *parser_state() { return parser_state_.get(); }
// Accessor for the underlying sentence object.
SyntaxNetSentence *sentence() { return sentence_; }
ComponentTrace *mutable_trace() {
CHECK(trace_) << "Trace is not initialized";
return trace_.get();
}
void set_trace(std::unique_ptr<ComponentTrace> trace) {
trace_ = std::move(trace);
}
private:
// Underlying ParserState object that is being wrapped.
std::unique_ptr<ParserState> parser_state_;
// Sentence object that is being examined with this state.
SyntaxNetSentence *sentence_;
// The current score of this state.
float score_;
// The current beam index of this state.
int current_beam_index_;
// The parent beam index for this state.
int parent_beam_index_;
// Maintains a list of which steps in the history correspond to
// representations for each of the tokens on the stack.
std::vector<int> step_for_token_;
// Maintains a list of which steps in the history correspond to the actions
// that assigned a parent for tokens when reduced.
std::vector<int> parent_step_for_token_;
// Maintain the parent index of a token in the system.
std::vector<int> parent_for_token_;
// Trace of the history to produce this state.
std::unique_ptr<ComponentTrace> trace_;
};
} // namespace dragnn
} // namespace syntaxnet
#endif // NLP_SAFT_OPENSOURCE_DRAGNN_COMPONENTS_SYNTAXNET_SYNTAXNET_TRANSITION_STATE_H_
#include "dragnn/components/syntaxnet/syntaxnet_transition_state.h"
#include "dragnn/components/syntaxnet/syntaxnet_component.h"
#include "dragnn/core/input_batch_cache.h"
#include "dragnn/core/test/generic.h"
#include "dragnn/core/test/mock_transition_state.h"
#include "dragnn/io/sentence_input_batch.h"
#include "dragnn/protos/spec.pb.h"
#include "syntaxnet/sentence.pb.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/io/path.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/protobuf.h"
#include "tensorflow/core/platform/test.h"
// This test suite is intended to validate the contracts that the DRAGNN
// system expects from all transition state subclasses. Developers creating
// new TransitionStates should copy this test and modify it as necessary,
// using it to ensure their state conforms to DRAGNN expectations.
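// The contract exercised below covers the CloneableTransitionState interface:
// Init(), Clone(), ParentBeamIndex(), Get/SetBeamIndex(), Get/SetScore(), the
// per-token step and parent accessors, and trace handling.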
namespace syntaxnet {
namespace dragnn {
namespace {
const char kSentence0[] = R"(
token {
word: "Sentence" start: 0 end: 7 tag: "NN" category: "NOUN" label: "ROOT"
break_level: NO_BREAK
}
token {
word: "0" start: 9 end: 9 head: 0 tag: "CD" category: "NUM" label: "num"
break_level: SPACE_BREAK
}
token {
word: "." start: 10 end: 10 head: 0 tag: "." category: "." label: "punct"
break_level: NO_BREAK
}
)";
} // namespace
using testing::Return;
class SyntaxNetTransitionStateTest : public ::testing::Test {
public:
std::unique_ptr<SyntaxNetTransitionState> CreateState() {
// Get the master spec proto from the test data directory.
MasterSpec master_spec;
string file_name = tensorflow::io::JoinPath(
test::GetTestDataPrefix(), "dragnn/components/syntaxnet/testdata",
"master_spec.textproto");
TF_CHECK_OK(tensorflow::ReadTextProto(tensorflow::Env::Default(), file_name,
&master_spec));
// Get all the resource protos from the test data directory.
for (Resource &resource :
*(master_spec.mutable_component(0)->mutable_resource())) {
resource.mutable_part(0)->set_file_pattern(tensorflow::io::JoinPath(
test::GetTestDataPrefix(), "dragnn/components/syntaxnet/testdata",
resource.part(0).file_pattern()));
}
// Create an empty input batch and beam vector to initialize the parser.
Sentence sentence_0;
TextFormat::ParseFromString(kSentence0, &sentence_0);
string sentence_0_str;
sentence_0.SerializeToString(&sentence_0_str);
data_.reset(new InputBatchCache(sentence_0_str));
SentenceInputBatch *sentences = data_->GetAs<SentenceInputBatch>();
// Create a parser component that will generate a parser state for this
// test.
SyntaxNetComponent component;
component.InitializeComponent(*(master_spec.mutable_component(0)));
std::vector<std::vector<const TransitionState *>> states;
constexpr int kBeamSize = 1;
component.InitializeData(states, kBeamSize, data_.get());
// Get a transition state from the component.
std::unique_ptr<SyntaxNetTransitionState> test_state =
component.CreateState(&(sentences->data()->at(0)));
return test_state;
}
std::unique_ptr<InputBatchCache> data_;
};
// Validates the consistency of the beam index setter and getter.
TEST_F(SyntaxNetTransitionStateTest, CanSetAndGetBeamIndex) {
// Create and initialize a test state.
MockTransitionState mock_state;
auto test_state = CreateState();
test_state->Init(mock_state);
constexpr int kOldBeamIndex = 12;
test_state->SetBeamIndex(kOldBeamIndex);
EXPECT_EQ(test_state->GetBeamIndex(), kOldBeamIndex);
constexpr int kNewBeamIndex = 7;
test_state->SetBeamIndex(kNewBeamIndex);
EXPECT_EQ(test_state->GetBeamIndex(), kNewBeamIndex);
}
// Validates the consistency of the score setter and getter.
TEST_F(SyntaxNetTransitionStateTest, CanSetAndGetScore) {
// Create and initialize a test state.
MockTransitionState mock_state;
auto test_state = CreateState();
test_state->Init(mock_state);
constexpr float kOldScore = 12.1;
test_state->SetScore(kOldScore);
EXPECT_EQ(test_state->GetScore(), kOldScore);
constexpr float kNewScore = 7.2;
test_state->SetScore(kNewScore);
EXPECT_EQ(test_state->GetScore(), kNewScore);
}
// This test ensures that the initializing state's current index is saved
// as the parent beam index of the state being initialized.
TEST_F(SyntaxNetTransitionStateTest, ReportsParentBeamIndex) {
// Create a mock transition state that will report a specific current index.
// This index should become the parent state index for the test state.
MockTransitionState mock_state;
constexpr int kParentBeamIndex = 1138;
EXPECT_CALL(mock_state, GetBeamIndex())
.WillRepeatedly(Return(kParentBeamIndex));
auto test_state = CreateState();
test_state->Init(mock_state);
EXPECT_EQ(test_state->ParentBeamIndex(), kParentBeamIndex);
}
// This test ensures that the initializing state's current score is saved
// as the current score of the state being initialized.
TEST_F(SyntaxNetTransitionStateTest, InitializationCopiesParentScore) {
// Create a mock transition state that will report a specific score.
// This score should become the score of the test state.
MockTransitionState mock_state;
constexpr float kParentScore = 24.12;
EXPECT_CALL(mock_state, GetScore()).WillRepeatedly(Return(kParentScore));
auto test_state = CreateState();
test_state->Init(mock_state);
EXPECT_EQ(test_state->GetScore(), kParentScore);
}
// This test ensures that calling Clone maintains the state data (parent beam
// index, beam index, score, etc.) of the state that was cloned.
TEST_F(SyntaxNetTransitionStateTest, CloningMaintainsState) {
// Create and initialize the state.
MockTransitionState mock_state;
constexpr int kParentBeamIndex = 1138;
EXPECT_CALL(mock_state, GetBeamIndex())
.WillRepeatedly(Return(kParentBeamIndex));
auto test_state = CreateState();
test_state->Init(mock_state);
// Validate the internal state of the test state.
constexpr float kOldScore = 20.0;
test_state->SetScore(kOldScore);
EXPECT_EQ(test_state->GetScore(), kOldScore);
constexpr int kOldBeamIndex = 12;
test_state->SetBeamIndex(kOldBeamIndex);
EXPECT_EQ(test_state->GetBeamIndex(), kOldBeamIndex);
auto clone = test_state->Clone();
// The clone should have identical state to the old state.
EXPECT_EQ(clone->ParentBeamIndex(), kParentBeamIndex);
EXPECT_EQ(clone->GetScore(), kOldScore);
EXPECT_EQ(clone->GetBeamIndex(), kOldBeamIndex);
}
// Validates the consistency of the step_for_token setter and getter.
TEST_F(SyntaxNetTransitionStateTest, CanSetAndGetStepForToken) {
// Create and initialize a test state.
MockTransitionState mock_state;
auto test_state = CreateState();
test_state->Init(mock_state);
constexpr int kStepForTokenZero = 12;
constexpr int kStepForTokenTwo = 34;
test_state->set_step_for_token(0, kStepForTokenZero);
test_state->set_step_for_token(2, kStepForTokenTwo);
// Expect that the set tokens return values and the unset steps return the
// default.
constexpr int kDefaultValue = -1;
EXPECT_EQ(kStepForTokenZero, test_state->step_for_token(0));
EXPECT_EQ(kDefaultValue, test_state->step_for_token(1));
EXPECT_EQ(kStepForTokenTwo, test_state->step_for_token(2));
// Expect that out-of-bounds accesses will return the default. (There are only
// 3 tokens in the backing sentence, so token 3 and greater are out of bounds.)
EXPECT_EQ(kDefaultValue, test_state->step_for_token(-1));
EXPECT_EQ(kDefaultValue, test_state->step_for_token(3));
}
// Validates the consistency of the parent_step_for_token setter and getter.
TEST_F(SyntaxNetTransitionStateTest, CanSetAndGetParentStepForToken) {
// Create and initialize a test state.
MockTransitionState mock_state;
auto test_state = CreateState();
test_state->Init(mock_state);
constexpr int kStepForTokenZero = 12;
constexpr int kStepForTokenTwo = 34;
test_state->set_parent_step_for_token(0, kStepForTokenZero);
test_state->set_parent_step_for_token(2, kStepForTokenTwo);
// Expect that the set tokens return values and the unset steps return the
// default.
constexpr int kDefaultValue = -1;
EXPECT_EQ(kStepForTokenZero, test_state->parent_step_for_token(0));
EXPECT_EQ(kDefaultValue, test_state->parent_step_for_token(1));
EXPECT_EQ(kStepForTokenTwo, test_state->parent_step_for_token(2));
// Expect that out-of-bounds accesses will return the default. (There are only
// 3 tokens in the backing sentence, so token 3 and greater are out of bounds.)
EXPECT_EQ(kDefaultValue, test_state->parent_step_for_token(-1));
EXPECT_EQ(kDefaultValue, test_state->parent_step_for_token(3));
}
// Validates the consistency of the parent_for_token setter and getter.
TEST_F(SyntaxNetTransitionStateTest, CanSetAndGetParentForToken) {
// Create and initialize a test state.
MockTransitionState mock_state;
auto test_state = CreateState();
test_state->Init(mock_state);
constexpr int kParentForTokenZero = 12;
constexpr int kParentForTokenTwo = 34;
test_state->set_parent_for_token(0, kParentForTokenZero);
test_state->set_parent_for_token(2, kParentForTokenTwo);
// Expect that the set tokens return values and the unset steps return the
// default.
constexpr int kDefaultValue = -1;
EXPECT_EQ(kParentForTokenZero, test_state->parent_for_token(0));
EXPECT_EQ(kDefaultValue, test_state->parent_for_token(1));
EXPECT_EQ(kParentForTokenTwo, test_state->parent_for_token(2));
// Expect that out-of-bounds accesses will return the default. (There are only
// 3 tokens in the backing sentence, so token 3 and greater are out of bounds.)
EXPECT_EQ(kDefaultValue, test_state->parent_for_token(-1));
EXPECT_EQ(kDefaultValue, test_state->parent_for_token(3));
}
// Validates the consistency of trace proto setter/getter.
TEST_F(SyntaxNetTransitionStateTest, CanSetAndGetTrace) {
// Create and initialize a test state.
MockTransitionState mock_state;
auto test_state = CreateState();
test_state->Init(mock_state);
const string kTestComponentName = "test";
std::unique_ptr<ComponentTrace> trace;
trace.reset(new ComponentTrace());
trace->set_name(kTestComponentName);
test_state->set_trace(std::move(trace));
EXPECT_EQ(trace.get(), nullptr);
EXPECT_EQ(test_state->mutable_trace()->name(), kTestComponentName);
// The trace should be preserved when cloning.
auto cloned_state = test_state->Clone();
EXPECT_EQ(cloned_state->mutable_trace()->name(), kTestComponentName);
EXPECT_EQ(test_state->mutable_trace()->name(), kTestComponentName);
}
} // namespace dragnn
} // namespace syntaxnet
component {
name: "parser"
transition_system {
registered_name: "arc-standard"
}
resource {
name: 'label-map'
part {
file_pattern: 'syntaxnet-tagger.label-map'
file_format: 'text'
}
}
resource {
name: 'tag-map'
part {
file_pattern: 'syntaxnet-tagger.tag-map'
file_format: 'text'
}
}
fixed_feature {
name: "tags"
fml: "input.tag input(1).tag"
embedding_dim: 32
vocabulary_size: 46
size: 2
predicate_map: "hashed"
}
fixed_feature {
name: "tags"
fml: "input(-1).tag input.tag input(1).tag"
embedding_dim: 32
vocabulary_size: 46
size: 3
predicate_map: "hashed"
}
linked_feature {
name: "recurrent_stack"
fml: "stack.focus stack(1).focus"
embedding_dim: 32
size: 2
source_component: "parser"
source_translator: "identity"
source_layer: "hidden_0"
}
backend {
registered_name: "SyntaxNetComponent"
}
}
46
punct 243160
prep 194627
pobj 186958
det 170592
nsubj 144821
nn 144800
amod 117242
ROOT 90592
dobj 88551
aux 76523
advmod 72893
conj 59384
cc 57532
num 36350
poss 35117
dep 34986
ccomp 29470
cop 25991
mark 25141
xcomp 25111
rcmod 16234
auxpass 15740
advcl 14996
possessive 14866
nsubjpass 14133
pcomp 12488
appos 11112
partmod 11106
neg 11090
number 10658
prt 7123
quantmod 6653
tmod 5418
infmod 5134
npadvmod 3213
parataxis 3012
mwe 2793
expl 2712
iobj 1642
acomp 1632
discourse 1381
csubj 1225
predet 1160
preconj 749
goeswith 146
csubjpass 41
component {
name: "tagger"
num_actions : 49
transition_system {
registered_name: "tagger"
parameters {
key: "join_category_to_pos"
value: "true"
}
}
resource {
name: "tag-map"
part {
file_pattern: "TESTDATA/syntaxnet-tagger.tag-map"
file_format: "text"
}
}
resource {
name: "word-map"
part {
file_pattern: "TESTDATA/syntaxnet-tagger.word-map"
file_format: "text"
}
}
resource {
name: "label-map"
part {
file_pattern: "TESTDATA/syntaxnet-tagger.label-map"
file_format: "text"
}
}
fixed_feature {
name: "words"
fml: "input(-1).word input(-2).word input(-3).word input.word input(1).word input(2).word input(3).word"
embedding_dim: 64
vocabulary_size: 39397
size: 7
}
fixed_feature {
name: "words"
fml: "input(-3).word input.word input(1).word input(2).word input(3).word"
embedding_dim: 64
vocabulary_size: 39397
size: 5
}
linked_feature {
name: "rnn"
fml: "stack.focus"
embedding_dim: 32
size: 1
source_component: "tagger"
source_translator: "shift-reduce-step"
source_layer: "layer_0"
}
backend {
registered_name: "SyntaxNetComponent"
}
network_unit {
registered_name: 'feed-forward'
parameters {
key: 'hidden_layer_sizes'
value: '64'
}
}
}