Commit 4364390a authored by Ivan Bogatyy's avatar Ivan Bogatyy Committed by calberti
Browse files

Release DRAGNN bulk networks (#2785)

* Release DRAGNN bulk networks
parent 638fd759
# Java baseimage, for Bazel. FROM ubuntu:16.10
FROM openjdk:8
ENV SYNTAXNETDIR=/opt/tensorflow PATH=$PATH:/root/bin ENV SYNTAXNETDIR=/opt/tensorflow PATH=$PATH:/root/bin
...@@ -21,13 +20,15 @@ RUN mkdir -p $SYNTAXNETDIR \ ...@@ -21,13 +20,15 @@ RUN mkdir -p $SYNTAXNETDIR \
libopenblas-dev \ libopenblas-dev \
libpng-dev \ libpng-dev \
libxft-dev \ libxft-dev \
patch \ openjdk-8-jdk \
python-dev \ python-dev \
python-mock \ python-mock \
python-pip \ python-pip \
python2.7 \ python2.7 \
swig \ swig \
unzip \
vim \ vim \
wget \
zlib1g-dev \ zlib1g-dev \
&& apt-get clean \ && apt-get clean \
&& (rm -f /var/cache/apt/archives/*.deb \ && (rm -f /var/cache/apt/archives/*.deb \
...@@ -55,7 +56,7 @@ RUN python -m pip install \ ...@@ -55,7 +56,7 @@ RUN python -m pip install \
--py --sys-prefix widgetsnbextension \ --py --sys-prefix widgetsnbextension \
&& rm -rf /root/.cache/pip /tmp/pip* && rm -rf /root/.cache/pip /tmp/pip*
# Installs the latest version of Bazel. # Installs Bazel.
RUN wget --quiet https://github.com/bazelbuild/bazel/releases/download/0.5.4/bazel-0.5.4-installer-linux-x86_64.sh \ RUN wget --quiet https://github.com/bazelbuild/bazel/releases/download/0.5.4/bazel-0.5.4-installer-linux-x86_64.sh \
&& chmod +x bazel-0.5.4-installer-linux-x86_64.sh \ && chmod +x bazel-0.5.4-installer-linux-x86_64.sh \
&& ./bazel-0.5.4-installer-linux-x86_64.sh \ && ./bazel-0.5.4-installer-linux-x86_64.sh \
...@@ -65,13 +66,11 @@ COPY WORKSPACE $SYNTAXNETDIR/syntaxnet/WORKSPACE ...@@ -65,13 +66,11 @@ COPY WORKSPACE $SYNTAXNETDIR/syntaxnet/WORKSPACE
COPY tools/bazel.rc $SYNTAXNETDIR/syntaxnet/tools/bazel.rc COPY tools/bazel.rc $SYNTAXNETDIR/syntaxnet/tools/bazel.rc
COPY tensorflow $SYNTAXNETDIR/syntaxnet/tensorflow COPY tensorflow $SYNTAXNETDIR/syntaxnet/tensorflow
# Workaround solving the PYTHON_BIN_PATH not found problem
ENV PYTHON_BIN_PATH=/usr/bin/python
# Compile common TensorFlow targets, which don't depend on DRAGNN / SyntaxNet # Compile common TensorFlow targets, which don't depend on DRAGNN / SyntaxNet
# source. This makes it more convenient to re-compile DRAGNN / SyntaxNet for # source. This makes it more convenient to re-compile DRAGNN / SyntaxNet for
# development (though not as convenient as the docker-devel scripts). # development (though not as convenient as the docker-devel scripts).
RUN cd $SYNTAXNETDIR/syntaxnet/tensorflow \ RUN cd $SYNTAXNETDIR/syntaxnet/tensorflow \
&& ./configure CPU \ && tensorflow/tools/ci_build/builds/configured CPU \
&& cd $SYNTAXNETDIR/syntaxnet \ && cd $SYNTAXNETDIR/syntaxnet \
&& bazel build -c opt @org_tensorflow//tensorflow:tensorflow_py && bazel build -c opt @org_tensorflow//tensorflow:tensorflow_py
...@@ -92,4 +91,4 @@ EXPOSE 8888 ...@@ -92,4 +91,4 @@ EXPOSE 8888
COPY examples $SYNTAXNETDIR/syntaxnet/examples COPY examples $SYNTAXNETDIR/syntaxnet/examples
# Todo: Move this earlier in the file (don't want to invalidate caches for now). # Todo: Move this earlier in the file (don't want to invalidate caches for now).
CMD /bin/bash -c "bazel-bin/dragnn/tools/oss_notebook_launcher notebook --debug --notebook-dir=/opt/tensorflow/syntaxnet/examples --allow-root" CMD /bin/bash -c "bazel-bin/dragnn/tools/oss_notebook_launcher notebook --debug --notebook-dir=/opt/tensorflow/syntaxnet/examples"
...@@ -23,8 +23,8 @@ This repository is largely divided into two sub-packages: ...@@ -23,8 +23,8 @@ This repository is largely divided into two sub-packages:
[documentation](g3doc/DRAGNN.md), [documentation](g3doc/DRAGNN.md),
[paper](https://arxiv.org/pdf/1703.04474.pdf)** implements Dynamic Recurrent [paper](https://arxiv.org/pdf/1703.04474.pdf)** implements Dynamic Recurrent
Acyclic Graphical Neural Networks (DRAGNN), a framework for building Acyclic Graphical Neural Networks (DRAGNN), a framework for building
multi-task, fully dynamically constructed computation graphs. Practically, we multi-task, fully dynamically constructed computation graphs. Practically,
use DRAGNN to extend our prior work from [Andor et al. we use DRAGNN to extend our prior work from [Andor et al.
(2016)](http://arxiv.org/abs/1603.06042) with end-to-end, deep recurrent (2016)](http://arxiv.org/abs/1603.06042) with end-to-end, deep recurrent
models and to provide a much easier to use interface to SyntaxNet. *DRAGNN models and to provide a much easier to use interface to SyntaxNet. *DRAGNN
is designed first and foremost as a Python library, and therefore much is designed first and foremost as a Python library, and therefore much
...@@ -54,20 +54,47 @@ There are three ways to use SyntaxNet: ...@@ -54,20 +54,47 @@ There are three ways to use SyntaxNet:
### Docker installation ### Docker installation
_This process takes ~10 minutes._
The simplest way to get started with DRAGNN is by loading our Docker container. The simplest way to get started with DRAGNN is by loading our Docker container.
[Here](g3doc/CLOUD.md) is a tutorial for running the DRAGNN container on [Here](g3doc/CLOUD.md) is a tutorial for running the DRAGNN container on
[GCP](https://cloud.google.com) (just as applicable to your own computer). [GCP](https://cloud.google.com) (just as applicable to your own computer).
### Ubuntu 16.10+ binary installation
_This process takes ~5 minutes, but is only compatible with Linux systems whose
C++ standard library provides GLIBCXX 3.4.22 or above (e.g. Ubuntu 16.10)._
Binary wheel packages are provided for TensorFlow and SyntaxNet. If you do not
need to write new binary TensorFlow ops, these should suffice.
* `apt-get install -y graphviz libgraphviz-dev libopenblas-base libpng16-16
libxft2 python-pip python-mock`
* `pip install pygraphviz
--install-option="--include-path=/usr/include/graphviz"
--install-option="--library-path=/usr/lib/graphviz/"`
* `pip install 'ipython<6.0' protobuf numpy scipy jupyter
syntaxnet-with-tensorflow`
* `python -m jupyter_core.command nbextension enable --py --sys-prefix
widgetsnbextension`
You can test that binary modules can be successfully imported by running,
* `python -c 'import dragnn.python.load_dragnn_cc_impl,
syntaxnet.load_parser_ops'`
### Manual installation ### Manual installation
_This process takes 1-2 hours._
Running and training SyntaxNet/DRAGNN models requires building this package from Running and training SyntaxNet/DRAGNN models requires building this package from
source. You'll need to install: source. You'll need to install:
* python 2.7: * python 2.7:
* Python 3 support is not available yet * Python 3 support is not available yet
* bazel: * bazel 0.5.4:
* Follow the instructions [here](http://bazel.build/docs/install.html) * Follow the instructions [here](http://bazel.build/docs/install.html)
* Alternately, Download bazel <.deb> from * Alternately, Download bazel 0.5.4 <.deb> from
[https://github.com/bazelbuild/bazel/releases](https://github.com/bazelbuild/bazel/releases) [https://github.com/bazelbuild/bazel/releases](https://github.com/bazelbuild/bazel/releases)
for your system configuration. for your system configuration.
* Install it using the command: sudo dpkg -i <.deb file> * Install it using the command: sudo dpkg -i <.deb file>
...@@ -103,9 +130,12 @@ following commands: ...@@ -103,9 +130,12 @@ following commands:
bazel test --linkopt=-headerpad_max_install_names \ bazel test --linkopt=-headerpad_max_install_names \
dragnn/... syntaxnet/... util/utf8/... dragnn/... syntaxnet/... util/utf8/...
``` ```
Bazel should complete reporting all tests passed. Bazel should complete reporting all tests passed.
Now you can install the SyntaxNet and DRAGNN Python modules with the following commands: Now you can install the SyntaxNet and DRAGNN Python modules with the following
commands:
```shell ```shell
mkdir /tmp/syntaxnet_pkg mkdir /tmp/syntaxnet_pkg
bazel-bin/dragnn/tools/build_pip_package --output-dir=/tmp/syntaxnet_pkg bazel-bin/dragnn/tools/build_pip_package --output-dir=/tmp/syntaxnet_pkg
...@@ -116,8 +146,6 @@ Now you can install the SyntaxNet and DRAGNN Python modules with the following c ...@@ -116,8 +146,6 @@ Now you can install the SyntaxNet and DRAGNN Python modules with the following c
To build SyntaxNet with GPU support please refer to the instructions in To build SyntaxNet with GPU support please refer to the instructions in
[issues/248](https://github.com/tensorflow/models/issues/248). [issues/248](https://github.com/tensorflow/models/issues/248).
**Note:** If you are running Docker on OSX, make sure that you have enough **Note:** If you are running Docker on OSX, make sure that you have enough
memory allocated for your Docker VM. memory allocated for your Docker VM.
......
# Test image: starts from the prebuilt test-base image (which presumably
# already contains the compiled TensorFlow dependencies — confirm against the
# test-base Dockerfile) and layers the current DRAGNN / SyntaxNet sources on
# top of it.
FROM dragnn-oss-test-base:latest
# Remove the stale source trees baked into the base image so the COPY steps
# below cannot leave orphaned files from an older checkout behind.
RUN rm -rf \
$SYNTAXNETDIR/syntaxnet/dragnn \
$SYNTAXNETDIR/syntaxnet/syntaxnet \
$SYNTAXNETDIR/syntaxnet/third_party \
$SYNTAXNETDIR/syntaxnet/util/utf8
# Copy in the sources that the tests should run against.
COPY dragnn $SYNTAXNETDIR/syntaxnet/dragnn
COPY syntaxnet $SYNTAXNETDIR/syntaxnet/syntaxnet
COPY third_party $SYNTAXNETDIR/syntaxnet/third_party
COPY util/utf8 $SYNTAXNETDIR/syntaxnet/util/utf8
FROM ubuntu:16.10
ENV SYNTAXNETDIR=/opt/tensorflow PATH=$PATH:/root/bin
# Install system packages. This doesn't include everything the TensorFlow
# dockerfile specifies, so if anything goes awry, maybe install more packages
# from there. Also, running apt-get clean before further commands will make the
# Docker images smaller.
RUN mkdir -p $SYNTAXNETDIR \
&& cd $SYNTAXNETDIR \
&& apt-get update \
&& apt-get install -y \
file \
git \
graphviz \
libcurl3-dev \
libfreetype6-dev \
libgraphviz-dev \
liblapack-dev \
libopenblas-dev \
libpng-dev \
libxft-dev \
openjdk-8-jdk \
python-dev \
python-mock \
python-pip \
python2.7 \
swig \
unzip \
vim \
wget \
zlib1g-dev \
&& apt-get clean \
&& (rm -f /var/cache/apt/archives/*.deb \
/var/cache/apt/archives/partial/*.deb /var/cache/apt/*.bin || true)
# Install common Python dependencies. Similar to above, remove caches
# afterwards to help keep Docker images smaller.
RUN pip install --ignore-installed pip \
&& python -m pip install numpy \
&& rm -rf /root/.cache/pip /tmp/pip*
RUN python -m pip install \
asciitree \
ipykernel \
jupyter \
matplotlib \
pandas \
protobuf \
scipy \
sklearn \
&& python -m ipykernel.kernelspec \
&& python -m pip install pygraphviz \
--install-option="--include-path=/usr/include/graphviz" \
--install-option="--library-path=/usr/lib/graphviz/" \
&& python -m jupyter_core.command nbextension enable \
--py --sys-prefix widgetsnbextension \
&& rm -rf /root/.cache/pip /tmp/pip*
# Installs Bazel. JAVA_HOME must point at the JDK installed above for the
# Bazel installer to find it.
RUN wget --quiet https://github.com/bazelbuild/bazel/releases/download/0.5.3/bazel-0.5.3-installer-linux-x86_64.sh \
&& chmod +x bazel-0.5.3-installer-linux-x86_64.sh \
&& JAVA_HOME=/usr/lib/jvm/java-8-openjdk-amd64/ ./bazel-0.5.3-installer-linux-x86_64.sh \
&& rm ./bazel-0.5.3-installer-linux-x86_64.sh
COPY WORKSPACE $SYNTAXNETDIR/syntaxnet/WORKSPACE
COPY tools/bazel.rc $SYNTAXNETDIR/syntaxnet/tools/bazel.rc
# Compile common TensorFlow targets, which don't depend on DRAGNN / SyntaxNet
# source. This makes it more convenient to re-compile DRAGNN / SyntaxNet for
# development (though not as convenient as the docker-devel scripts).
RUN cd $SYNTAXNETDIR/syntaxnet \
&& git clone --branch r1.3 --recurse-submodules https://github.com/tensorflow/tensorflow \
&& cd tensorflow \
# This line removes a bad archive target which causes Tensorflow install
# to fail. (Docker strips comment lines before joining continuations, so
# these comments are safe inside the RUN instruction.)
&& sed -i '\@https://github.com/google/protobuf/archive/0b059a3d8a8f8aa40dde7bea55edca4ec5dfea66.tar.gz@d' tensorflow/workspace.bzl \
# BUGFIX: this line previously ended in "\\" (an escaped backslash), which
# terminated the RUN command and made the next "&& cd" line an invalid
# Dockerfile instruction. A single "\" restores the line continuation.
&& tensorflow/tools/ci_build/builds/configured CPU \
&& cd $SYNTAXNETDIR/syntaxnet \
&& bazel build -c opt @org_tensorflow//tensorflow:tensorflow_py
# Just copy the code and run tests. The build and test flags differ enough that
# doing a normal build of TensorFlow targets doesn't save much test time.
WORKDIR $SYNTAXNETDIR/syntaxnet
COPY dragnn $SYNTAXNETDIR/syntaxnet/dragnn
COPY syntaxnet $SYNTAXNETDIR/syntaxnet/syntaxnet
COPY third_party $SYNTAXNETDIR/syntaxnet/third_party
COPY util/utf8 $SYNTAXNETDIR/syntaxnet/util/utf8
# Doesn't matter if the tests pass or not, since we're going to re-copy over the
# code.
RUN bazel test -c opt ... || true
# You need to build wheels before building this image. Please consult # You need to build wheels before building this image. Please consult
# docker-devel/README.txt. # docker-devel/README.txt.
# This is the base of the openjdk image.
# #
# It might be more efficient to use a minimal distribution, like Alpine. But # It might be more efficient to use a minimal distribution, like Alpine. But
# the upside of this being popular is that people might already have it. # the upside of this being popular is that people might already have it.
FROM buildpack-deps:jessie-curl FROM ubuntu:16.10
ENV SYNTAXNETDIR=/opt/tensorflow PATH=$PATH:/root/bin ENV SYNTAXNETDIR=/opt/tensorflow PATH=$PATH:/root/bin
...@@ -19,7 +17,7 @@ RUN apt-get update \ ...@@ -19,7 +17,7 @@ RUN apt-get update \
libgraphviz-dev \ libgraphviz-dev \
liblapack3 \ liblapack3 \
libopenblas-base \ libopenblas-base \
libpng12-0 \ libpng16-16 \
libxft2 \ libxft2 \
python-dev \ python-dev \
python-mock \ python-mock \
...@@ -48,11 +46,13 @@ RUN python -m pip install \ ...@@ -48,11 +46,13 @@ RUN python -m pip install \
&& python -m pip install pygraphviz \ && python -m pip install pygraphviz \
--install-option="--include-path=/usr/include/graphviz" \ --install-option="--include-path=/usr/include/graphviz" \
--install-option="--library-path=/usr/lib/graphviz/" \ --install-option="--library-path=/usr/lib/graphviz/" \
&& python -m jupyter_core.command nbextension enable \
--py --sys-prefix widgetsnbextension \
&& rm -rf /root/.cache/pip /tmp/pip* && rm -rf /root/.cache/pip /tmp/pip*
COPY syntaxnet_with_tensorflow-0.2-cp27-none-linux_x86_64.whl $SYNTAXNETDIR/ COPY syntaxnet_with_tensorflow-0.2-cp27-cp27mu-linux_x86_64.whl $SYNTAXNETDIR/
RUN python -m pip install \ RUN python -m pip install \
$SYNTAXNETDIR/syntaxnet_with_tensorflow-0.2-cp27-none-linux_x86_64.whl \ $SYNTAXNETDIR/syntaxnet_with_tensorflow-0.2-cp27-cp27mu-linux_x86_64.whl \
&& rm -rf /root/.cache/pip /tmp/pip* && rm -rf /root/.cache/pip /tmp/pip*
# This makes the IP exposed actually "*"; we'll do host restrictions by passing # This makes the IP exposed actually "*"; we'll do host restrictions by passing
...@@ -63,4 +63,4 @@ EXPOSE 8888 ...@@ -63,4 +63,4 @@ EXPOSE 8888
# This does not need to be compiled, only copied. # This does not need to be compiled, only copied.
COPY examples $SYNTAXNETDIR/syntaxnet/examples COPY examples $SYNTAXNETDIR/syntaxnet/examples
# For some reason, this works if we run it in a bash shell :/ :/ :/ # For some reason, this works if we run it in a bash shell :/ :/ :/
CMD /bin/bash -c "python -m jupyter_core.command notebook --debug --notebook-dir=/opt/tensorflow/syntaxnet/examples" CMD /bin/bash -c "python -m jupyter_core.command notebook --debug --notebook-dir=/opt/tensorflow/syntaxnet/examples --allow-root"
...@@ -43,11 +43,11 @@ Step 3: Building the development image ...@@ -43,11 +43,11 @@ Step 3: Building the development image
First, ensure you have the file First, ensure you have the file
syntaxnet_with_tensorflow-0.2-cp27-none-linux_x86_64.whl syntaxnet_with_tensorflow-0.2-cp27-cp27mu-linux_x86_64.whl
in your working directory, from step 2. Then run, in your working directory, from step 2. Then run,
docker build -t dragnn-oss:latest-minimal -f docker-devel/Dockerfile.min docker build -t dragnn-oss:latest-minimal -f docker-devel/Dockerfile.min .
If the filename changes (e.g. you are on a different architecture), just update If the filename changes (e.g. you are on a different architecture), just update
Dockerfile.min. Dockerfile.min.
......
...@@ -10,7 +10,6 @@ cc_library( ...@@ -10,7 +10,6 @@ cc_library(
"//dragnn/core:component_registry", "//dragnn/core:component_registry",
"//dragnn/core/interfaces:component", "//dragnn/core/interfaces:component",
"//dragnn/core/interfaces:transition_state", "//dragnn/core/interfaces:transition_state",
"//dragnn/io:sentence_input_batch",
"//dragnn/protos:data_proto", "//dragnn/protos:data_proto",
"//syntaxnet:base", "//syntaxnet:base",
], ],
......
...@@ -16,7 +16,6 @@ ...@@ -16,7 +16,6 @@
#include "dragnn/core/component_registry.h" #include "dragnn/core/component_registry.h"
#include "dragnn/core/interfaces/component.h" #include "dragnn/core/interfaces/component.h"
#include "dragnn/core/interfaces/transition_state.h" #include "dragnn/core/interfaces/transition_state.h"
#include "dragnn/io/sentence_input_batch.h"
#include "dragnn/protos/data.pb.h" #include "dragnn/protos/data.pb.h"
#include "syntaxnet/base.h" #include "syntaxnet/base.h"
...@@ -25,7 +24,8 @@ namespace dragnn { ...@@ -25,7 +24,8 @@ namespace dragnn {
namespace { namespace {
// A component that does not create its own transition states; instead, it // A component that does not create its own transition states; instead, it
// simply forwards the states of the previous component. Does not support all // simply forwards the states of the previous component. Requires that some
// previous component has converted the input batch. Does not support all
// methods. Intended for "compute-only" bulk components that only use linked // methods. Intended for "compute-only" bulk components that only use linked
// features, which use only a small subset of DRAGNN functionality. // features, which use only a small subset of DRAGNN functionality.
class StatelessComponent : public Component { class StatelessComponent : public Component {
...@@ -38,8 +38,7 @@ class StatelessComponent : public Component { ...@@ -38,8 +38,7 @@ class StatelessComponent : public Component {
void InitializeData( void InitializeData(
const std::vector<std::vector<const TransitionState *>> &parent_states, const std::vector<std::vector<const TransitionState *>> &parent_states,
int max_beam_size, InputBatchCache *input_data) override { int max_beam_size, InputBatchCache *input_data) override {
// Must use SentenceInputBatch to match SyntaxNetComponent. batch_size_ = input_data->Size();
batch_size_ = input_data->GetAs<SentenceInputBatch>()->data()->size();
beam_size_ = max_beam_size; beam_size_ = max_beam_size;
parent_states_ = parent_states; parent_states_ = parent_states;
...@@ -84,31 +83,34 @@ class StatelessComponent : public Component { ...@@ -84,31 +83,34 @@ class StatelessComponent : public Component {
LOG(FATAL) << "[" << name_ << "] Method not supported"; LOG(FATAL) << "[" << name_ << "] Method not supported";
return nullptr; return nullptr;
} }
void AdvanceFromPrediction(const float transition_matrix[], bool AdvanceFromPrediction(const float *transition_matrix, int num_items,
int matrix_length) override { int num_actions) override {
LOG(FATAL) << "[" << name_ << "] Method not supported"; LOG(FATAL) << "[" << name_ << "] AdvanceFromPrediction not supported";
} }
void AdvanceFromOracle() override { void AdvanceFromOracle() override {
LOG(FATAL) << "[" << name_ << "] Method not supported"; LOG(FATAL) << "[" << name_ << "] AdvanceFromOracle not supported";
} }
std::vector<std::vector<int>> GetOracleLabels() const override { std::vector<std::vector<int>> GetOracleLabels() const override {
LOG(FATAL) << "[" << name_ << "] Method not supported"; LOG(FATAL) << "[" << name_ << "] Method not supported";
return {};
} }
int GetFixedFeatures(std::function<int32 *(int)> allocate_indices, int GetFixedFeatures(std::function<int32 *(int)> allocate_indices,
std::function<int64 *(int)> allocate_ids, std::function<int64 *(int)> allocate_ids,
std::function<float *(int)> allocate_weights, std::function<float *(int)> allocate_weights,
int channel_id) const override { int channel_id) const override {
LOG(FATAL) << "[" << name_ << "] Method not supported"; LOG(FATAL) << "[" << name_ << "] Method not supported";
return 0;
} }
int BulkGetFixedFeatures(const BulkFeatureExtractor &extractor) override { int BulkGetFixedFeatures(const BulkFeatureExtractor &extractor) override {
LOG(FATAL) << "[" << name_ << "] Method not supported"; LOG(FATAL) << "[" << name_ << "] Method not supported";
return 0;
} }
// Bulk embedding of fixed features is not supported by this component (it
// uses only linked features — see the class comment); crashes via LOG(FATAL)
// if called.
void BulkEmbedFixedFeatures(
int batch_size_padding, int num_steps_padding, int output_array_size,
const vector<const float *> &per_channel_embeddings,
float *embedding_output) override {
LOG(FATAL) << "[" << name_ << "] Method not supported";
}
std::vector<LinkFeatures> GetRawLinkFeatures(int channel_id) const override { std::vector<LinkFeatures> GetRawLinkFeatures(int channel_id) const override {
LOG(FATAL) << "[" << name_ << "] Method not supported"; LOG(FATAL) << "[" << name_ << "] Method not supported";
return {};
} }
void AddTranslatedLinkFeaturesToTrace( void AddTranslatedLinkFeaturesToTrace(
const std::vector<LinkFeatures> &features, int channel_id) override { const std::vector<LinkFeatures> &features, int channel_id) override {
......
...@@ -18,6 +18,7 @@ ...@@ -18,6 +18,7 @@
#include "dragnn/core/test/generic.h" #include "dragnn/core/test/generic.h"
#include "dragnn/core/test/mock_transition_state.h" #include "dragnn/core/test/mock_transition_state.h"
#include "dragnn/io/sentence_input_batch.h" #include "dragnn/io/sentence_input_batch.h"
#include "dragnn/protos/data.pb.h"
#include "syntaxnet/base.h" #include "syntaxnet/base.h"
#include "syntaxnet/sentence.pb.h" #include "syntaxnet/sentence.pb.h"
#include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/core/errors.h"
...@@ -119,6 +120,10 @@ class StatelessComponentTest : public ::testing::Test { ...@@ -119,6 +120,10 @@ class StatelessComponentTest : public ::testing::Test {
CHECK(TextFormat::ParseFromString(kMasterSpec, &master_spec)); CHECK(TextFormat::ParseFromString(kMasterSpec, &master_spec));
data_.reset(new InputBatchCache(data)); data_.reset(new InputBatchCache(data));
// The stateless component does not use any particular input batch type, and
// relies on the preceding components to convert the input batch.
data_->GetAs<SentenceInputBatch>();
// Create a parser component with the specified beam size. // Create a parser component with the specified beam size.
std::unique_ptr<Component> parser_component( std::unique_ptr<Component> parser_component(
Component::Create("StatelessComponent")); Component::Create("StatelessComponent"));
...@@ -167,5 +172,37 @@ TEST_F(StatelessComponentTest, ForwardsTransitionStates) { ...@@ -167,5 +172,37 @@ TEST_F(StatelessComponentTest, ForwardsTransitionStates) {
EXPECT_EQ(parent_states, forwarded_states); EXPECT_EQ(parent_states, forwarded_states);
} }
// Verifies that each Component method StatelessComponent does not implement
// dies with its descriptive LOG(FATAL) message instead of silently
// misbehaving.
TEST_F(StatelessComponentTest, UnimplementedMethodsDie) {
  // No parent states are needed; the component only has to become ready.
  const std::vector<std::vector<const TransitionState *>> parent_states;

  // Serialize a few test sentences to form the input batch.
  std::vector<string> data;
  for (const string &textproto : {kSentence0, kSentence1, kLongSentence}) {
    Sentence sentence;
    CHECK(TextFormat::ParseFromString(textproto, &sentence));
    data.emplace_back();
    CHECK(sentence.SerializeToString(&data.back()));
  }

  const int kBeamSize = 2;
  auto test_parser = CreateParser(kBeamSize, parent_states, data);
  EXPECT_TRUE(test_parser->IsReady());

  // Each unsupported method must die with its specific error message.
  EXPECT_DEATH(test_parser->AdvanceFromPrediction(nullptr, 0, 0),
               "AdvanceFromPrediction not supported");
  EXPECT_DEATH(test_parser->AdvanceFromOracle(),
               "AdvanceFromOracle not supported");
  EXPECT_DEATH(test_parser->GetOracleLabels(), "Method not supported");
  EXPECT_DEATH(test_parser->GetFixedFeatures(nullptr, nullptr, nullptr, 0),
               "Method not supported");
  EXPECT_DEATH(test_parser->BulkEmbedFixedFeatures(0, 0, 0, {nullptr}, nullptr),
               "Method not supported");
  BulkFeatureExtractor extractor(nullptr, nullptr, nullptr);
  EXPECT_DEATH(test_parser->BulkGetFixedFeatures(extractor),
               "Method not supported");
  EXPECT_DEATH(test_parser->GetRawLinkFeatures(0), "Method not supported");
  EXPECT_DEATH(test_parser->AddTranslatedLinkFeaturesToTrace({}, 0),
               "Method not supported");
}
} // namespace dragnn } // namespace dragnn
} // namespace syntaxnet } // namespace syntaxnet
...@@ -28,6 +28,7 @@ ...@@ -28,6 +28,7 @@
#include "syntaxnet/sparse.pb.h" #include "syntaxnet/sparse.pb.h"
#include "syntaxnet/task_spec.pb.h" #include "syntaxnet/task_spec.pb.h"
#include "syntaxnet/utils.h" #include "syntaxnet/utils.h"
#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/logging.h"
namespace syntaxnet { namespace syntaxnet {
...@@ -105,7 +106,7 @@ void SyntaxNetComponent::InitializeComponent(const ComponentSpec &spec) { ...@@ -105,7 +106,7 @@ void SyntaxNetComponent::InitializeComponent(const ComponentSpec &spec) {
dims.push_back(StrCat(channel.embedding_dim())); dims.push_back(StrCat(channel.embedding_dim()));
} }
context.SetParameter("neurosis_feature_syntax_version", "2");
context.SetParameter("brain_parser_embedding_dims", utils::Join(dims, ";")); context.SetParameter("brain_parser_embedding_dims", utils::Join(dims, ";"));
context.SetParameter("brain_parser_predicate_maps", context.SetParameter("brain_parser_predicate_maps",
utils::Join(predicate_maps, ";")); utils::Join(predicate_maps, ";"));
...@@ -187,8 +188,9 @@ std::unique_ptr<Beam<SyntaxNetTransitionState>> SyntaxNetComponent::CreateBeam( ...@@ -187,8 +188,9 @@ std::unique_ptr<Beam<SyntaxNetTransitionState>> SyntaxNetComponent::CreateBeam(
return this->IsFinal(state); return this->IsFinal(state);
}; };
auto oracle_function = [this](SyntaxNetTransitionState *state) { auto oracle_function = [this](SyntaxNetTransitionState *state) {
VLOG(2) << "oracle_function action:" << this->GetOracleLabel(state); VLOG(2) << "oracle_function action:"
return this->GetOracleLabel(state); << tensorflow::str_util::Join(this->GetOracleVector(state), ", ");
return this->GetOracleVector(state);
}; };
auto beam_ptr = beam.get(); auto beam_ptr = beam.get();
auto advance_function = [this, beam_ptr](SyntaxNetTransitionState *state, auto advance_function = [this, beam_ptr](SyntaxNetTransitionState *state,
...@@ -335,25 +337,32 @@ std::function<int(int, int, int)> SyntaxNetComponent::GetStepLookupFunction( ...@@ -335,25 +337,32 @@ std::function<int(int, int, int)> SyntaxNetComponent::GetStepLookupFunction(
} }
} }
void SyntaxNetComponent::AdvanceFromPrediction(const float transition_matrix[], bool SyntaxNetComponent::AdvanceFromPrediction(const float *transition_matrix,
int transition_matrix_length) { int num_items, int num_actions) {
VLOG(2) << "Advancing from prediction."; VLOG(2) << "Advancing from prediction, component = " << spec_.name();
int matrix_index = 0; const int num_static_actions =
int num_labels = transition_system_->NumActions(label_map_->Size()); transition_system_->NumActions(label_map_->Size());
if (num_static_actions != ParserTransitionSystem::kDynamicNumActions) {
CHECK_EQ(num_static_actions, num_actions)
<< "[" << spec_.name()
<< "] static action set does not match transition matrix";
}
for (int i = 0; i < batch_.size(); ++i) { for (int i = 0; i < batch_.size(); ++i) {
int max_beam_size = batch_.at(i)->max_size(); const int size = num_actions * batch_[i]->max_size();
int matrix_size = num_labels * max_beam_size; if (!batch_[i]->IsTerminal()) {
CHECK_LE(matrix_index + matrix_size, transition_matrix_length); bool success = batch_[i]->AdvanceFromPrediction(transition_matrix, size,
if (!batch_.at(i)->IsTerminal()) { num_actions);
batch_.at(i)->AdvanceFromPrediction(&transition_matrix[matrix_index], if (!success) {
matrix_size, num_labels); return false;
}
} }
matrix_index += num_labels * max_beam_size; transition_matrix += size;
} }
return true;
} }
void SyntaxNetComponent::AdvanceFromOracle() { void SyntaxNetComponent::AdvanceFromOracle() {
VLOG(2) << "Advancing from oracle."; VLOG(2) << "Advancing from oracle, component = " << spec_.name();
for (auto &beam : batch_) { for (auto &beam : batch_) {
beam->AdvanceFromOracle(); beam->AdvanceFromOracle();
} }
...@@ -404,8 +413,18 @@ int SyntaxNetComponent::GetFixedFeatures( ...@@ -404,8 +413,18 @@ int SyntaxNetComponent::GetFixedFeatures(
features.emplace_back(f); features.emplace_back(f);
if (do_tracing_) { if (do_tracing_) {
FixedFeatures fixed_features; FixedFeatures fixed_features;
for (const string &name : f.description()) { CHECK_EQ(f.description_size(), f.id_size());
fixed_features.add_value_name(name); CHECK(f.weight_size() == 0 || f.weight_size() == f.id_size());
const bool has_weights = f.weight_size() != 0;
for (int i = 0; i < f.description_size(); ++i) {
if (has_weights) {
fixed_features.add_value_name(StrCat("id: ", f.id(i),
" name: ", f.description(i),
" weight: ", f.weight(i)));
} else {
fixed_features.add_value_name(
StrCat("id: ", f.id(i), " name: ", f.description(i)));
}
} }
fixed_features.set_feature_name(""); fixed_features.set_feature_name("");
auto *trace = GetLastStepInTrace(state->mutable_trace()); auto *trace = GetLastStepInTrace(state->mutable_trace());
...@@ -522,8 +541,8 @@ int SyntaxNetComponent::BulkGetFixedFeatures( ...@@ -522,8 +541,8 @@ int SyntaxNetComponent::BulkGetFixedFeatures(
// This would be a good place to add threading. // This would be a good place to add threading.
for (int channel_id = 0; channel_id < num_channels; ++channel_id) { for (int channel_id = 0; channel_id < num_channels; ++channel_id) {
int feature_count = feature_counts[channel_id]; int feature_count = feature_counts[channel_id];
LOG(INFO) << "Feature count is " << feature_count << " for channel " VLOG(2) << "Feature count is " << feature_count << " for channel "
<< channel_id; << channel_id;
int32 *indices_tensor = int32 *indices_tensor =
extractor.AllocateIndexMemory(channel_id, feature_count); extractor.AllocateIndexMemory(channel_id, feature_count);
int64 *ids_tensor = extractor.AllocateIdMemory(channel_id, feature_count); int64 *ids_tensor = extractor.AllocateIdMemory(channel_id, feature_count);
...@@ -603,7 +622,9 @@ std::vector<std::vector<int>> SyntaxNetComponent::GetOracleLabels() const { ...@@ -603,7 +622,9 @@ std::vector<std::vector<int>> SyntaxNetComponent::GetOracleLabels() const {
for (int beam_idx = 0; beam_idx < beam->size(); ++beam_idx) { for (int beam_idx = 0; beam_idx < beam->size(); ++beam_idx) {
// Get the raw link features from the linked feature extractor. // Get the raw link features from the linked feature extractor.
auto state = beam->beam_state(beam_idx); auto state = beam->beam_state(beam_idx);
oracle_labels.back().push_back(GetOracleLabel(state));
// Arbitrarily choose the first vector element.
oracle_labels.back().push_back(GetOracleVector(state).front());
} }
} }
return oracle_labels; return oracle_labels;
...@@ -661,13 +682,17 @@ bool SyntaxNetComponent::IsFinal(SyntaxNetTransitionState *state) const { ...@@ -661,13 +682,17 @@ bool SyntaxNetComponent::IsFinal(SyntaxNetTransitionState *state) const {
return transition_system_->IsFinalState(*(state->parser_state())); return transition_system_->IsFinalState(*(state->parser_state()));
} }
int SyntaxNetComponent::GetOracleLabel(SyntaxNetTransitionState *state) const { std::vector<int> SyntaxNetComponent::GetOracleVector(
SyntaxNetTransitionState *state) const {
if (IsFinal(state)) { if (IsFinal(state)) {
// It is not permitted to request an oracle label from a sentence that is // It is not permitted to request an oracle label from a sentence that is
// in a final state. // in a final state.
return -1; return {-1};
} else { } else {
return transition_system_->GetNextGoldAction(*(state->parser_state())); // TODO(googleuser): This should use the 'ParserAction' typedef.
std::vector<int> golds;
transition_system_->GetAllNextGoldActions(*(state->parser_state()), &golds);
return golds;
} }
} }
......
...@@ -13,8 +13,8 @@ ...@@ -13,8 +13,8 @@
// limitations under the License. // limitations under the License.
// ============================================================================= // =============================================================================
#ifndef NLP_SAFT_OPENSOURCE_DRAGNN_COMPONENTS_SYNTAXNET_SYNTAXNET_COMPONENT_H_ #ifndef DRAGNN_COMPONENTS_SYNTAXNET_SYNTAXNET_COMPONENT_H_
#define NLP_SAFT_OPENSOURCE_DRAGNN_COMPONENTS_SYNTAXNET_SYNTAXNET_COMPONENT_H_ #define DRAGNN_COMPONENTS_SYNTAXNET_SYNTAXNET_COMPONENT_H_
#include <vector> #include <vector>
...@@ -81,9 +81,10 @@ class SyntaxNetComponent : public Component { ...@@ -81,9 +81,10 @@ class SyntaxNetComponent : public Component {
std::function<int(int, int, int)> GetStepLookupFunction( std::function<int(int, int, int)> GetStepLookupFunction(
const string &method) override; const string &method) override;
// Advances this component from the given transition matrix. // Advances this component from the given transition matrix.Returns false
void AdvanceFromPrediction(const float transition_matrix[], // if the component could not be advanced.
int transition_matrix_length) override; bool AdvanceFromPrediction(const float *transition_matrix, int num_items,
int num_actions) override;
// Advances this component from the state oracles. // Advances this component from the state oracles.
void AdvanceFromOracle() override; void AdvanceFromOracle() override;
...@@ -105,6 +106,13 @@ class SyntaxNetComponent : public Component { ...@@ -105,6 +106,13 @@ class SyntaxNetComponent : public Component {
// component via the oracle until it is terminal. // component via the oracle until it is terminal.
int BulkGetFixedFeatures(const BulkFeatureExtractor &extractor) override; int BulkGetFixedFeatures(const BulkFeatureExtractor &extractor) override;
void BulkEmbedFixedFeatures(
int batch_size_padding, int num_steps_padding, int output_array_size,
const vector<const float *> &per_channel_embeddings,
float *embedding_matrix) override {
LOG(FATAL) << "Method not supported";
}
// Extracts and returns the vector of LinkFeatures for the specified // Extracts and returns the vector of LinkFeatures for the specified
// channel. Note: these are NOT translated. // channel. Note: these are NOT translated.
std::vector<LinkFeatures> GetRawLinkFeatures(int channel_id) const override; std::vector<LinkFeatures> GetRawLinkFeatures(int channel_id) const override;
...@@ -145,13 +153,13 @@ class SyntaxNetComponent : public Component { ...@@ -145,13 +153,13 @@ class SyntaxNetComponent : public Component {
bool IsFinal(SyntaxNetTransitionState *state) const; bool IsFinal(SyntaxNetTransitionState *state) const;
// Oracle function for this component. // Oracle function for this component.
int GetOracleLabel(SyntaxNetTransitionState *state) const; std::vector<int> GetOracleVector(SyntaxNetTransitionState *state) const;
// State advance function for this component. // State advance function for this component.
void Advance(SyntaxNetTransitionState *state, int action, void Advance(SyntaxNetTransitionState *state, int action,
Beam<SyntaxNetTransitionState> *beam); Beam<SyntaxNetTransitionState> *beam);
// Creates a new state for the given nlp_saft::SentenceExample. // Creates a new state for the given example.
std::unique_ptr<SyntaxNetTransitionState> CreateState( std::unique_ptr<SyntaxNetTransitionState> CreateState(
SyntaxNetSentence *example); SyntaxNetSentence *example);
...@@ -195,4 +203,4 @@ class SyntaxNetComponent : public Component { ...@@ -195,4 +203,4 @@ class SyntaxNetComponent : public Component {
} // namespace dragnn } // namespace dragnn
} // namespace syntaxnet } // namespace syntaxnet
#endif // NLP_SAFT_OPENSOURCE_DRAGNN_COMPONENTS_SYNTAXNET_SYNTAXNET_COMPONENT_H_ #endif // DRAGNN_COMPONENTS_SYNTAXNET_SYNTAXNET_COMPONENT_H_
...@@ -15,6 +15,8 @@ ...@@ -15,6 +15,8 @@
#include "dragnn/components/syntaxnet/syntaxnet_component.h" #include "dragnn/components/syntaxnet/syntaxnet_component.h"
#include <limits>
#include "dragnn/core/input_batch_cache.h" #include "dragnn/core/input_batch_cache.h"
#include "dragnn/core/test/generic.h" #include "dragnn/core/test/generic.h"
#include "dragnn/core/test/mock_transition_state.h" #include "dragnn/core/test/mock_transition_state.h"
...@@ -197,8 +199,8 @@ TEST_F(SyntaxNetComponentTest, AdvancesFromPredictionAndTerminates) { ...@@ -197,8 +199,8 @@ TEST_F(SyntaxNetComponentTest, AdvancesFromPredictionAndTerminates) {
// Transition the expected number of times. // Transition the expected number of times.
for (int i = 0; i < kExpectedNumTransitions; ++i) { for (int i = 0; i < kExpectedNumTransitions; ++i) {
EXPECT_FALSE(test_parser->IsTerminal()); EXPECT_FALSE(test_parser->IsTerminal());
test_parser->AdvanceFromPrediction(transition_matrix, EXPECT_TRUE(test_parser->AdvanceFromPrediction(transition_matrix, kBeamSize,
kNumPossibleTransitions * kBeamSize); kNumPossibleTransitions));
} }
// At this point, the test parser should be terminal. // At this point, the test parser should be terminal.
...@@ -225,6 +227,29 @@ TEST_F(SyntaxNetComponentTest, AdvancesFromPredictionAndTerminates) { ...@@ -225,6 +227,29 @@ TEST_F(SyntaxNetComponentTest, AdvancesFromPredictionAndTerminates) {
// TODO(googleuser): What should the finalized data look like? // TODO(googleuser): What should the finalized data look like?
} }
TEST_F(SyntaxNetComponentTest, AdvancesFromPredictionFailsWithNanWeights) {
// Create an empty input batch and beam vector to initialize the parser.
Sentence sentence_0;
TextFormat::ParseFromString(kSentence0, &sentence_0);
string sentence_0_str;
sentence_0.SerializeToString(&sentence_0_str);
auto test_parser = CreateParser({}, {sentence_0_str});
// There are 93 possible transitions for any given state. Create a transition
// array with a score of 10.0 for each transition.
constexpr int kBeamSize = 2;
constexpr int kNumPossibleTransitions = 93;
float transition_matrix[kNumPossibleTransitions * kBeamSize];
for (int i = 0; i < kNumPossibleTransitions * kBeamSize; ++i) {
transition_matrix[i] = std::numeric_limits<double>::quiet_NaN();
}
EXPECT_FALSE(test_parser->IsTerminal());
EXPECT_FALSE(test_parser->AdvanceFromPrediction(transition_matrix, kBeamSize,
kNumPossibleTransitions));
}
TEST_F(SyntaxNetComponentTest, RetainsPassedTransitionStateData) { TEST_F(SyntaxNetComponentTest, RetainsPassedTransitionStateData) {
// Create and initialize the state-> // Create and initialize the state->
MockTransitionState mock_state_one; MockTransitionState mock_state_one;
...@@ -269,8 +294,8 @@ TEST_F(SyntaxNetComponentTest, RetainsPassedTransitionStateData) { ...@@ -269,8 +294,8 @@ TEST_F(SyntaxNetComponentTest, RetainsPassedTransitionStateData) {
// Transition the expected number of times // Transition the expected number of times
for (int i = 0; i < kExpectedNumTransitions; ++i) { for (int i = 0; i < kExpectedNumTransitions; ++i) {
EXPECT_FALSE(test_parser->IsTerminal()); EXPECT_FALSE(test_parser->IsTerminal());
test_parser->AdvanceFromPrediction(transition_matrix, EXPECT_TRUE(test_parser->AdvanceFromPrediction(transition_matrix, kBeamSize,
kNumPossibleTransitions * kBeamSize); kNumPossibleTransitions));
} }
// At this point, the test parser should be terminal. // At this point, the test parser should be terminal.
...@@ -326,8 +351,8 @@ TEST_F(SyntaxNetComponentTest, AdvancesFromPredictionForMultiSentenceBatches) { ...@@ -326,8 +351,8 @@ TEST_F(SyntaxNetComponentTest, AdvancesFromPredictionForMultiSentenceBatches) {
// Transition the expected number of times. // Transition the expected number of times.
for (int i = 0; i < kExpectedNumTransitions; ++i) { for (int i = 0; i < kExpectedNumTransitions; ++i) {
EXPECT_FALSE(test_parser->IsTerminal()); EXPECT_FALSE(test_parser->IsTerminal());
test_parser->AdvanceFromPrediction( EXPECT_TRUE(test_parser->AdvanceFromPrediction(
transition_matrix, kNumPossibleTransitions * kBeamSize * kBatchSize); transition_matrix, kBeamSize * kBatchSize, kNumPossibleTransitions));
} }
// At this point, the test parser should be terminal. // At this point, the test parser should be terminal.
...@@ -382,8 +407,8 @@ TEST_F(SyntaxNetComponentTest, ...@@ -382,8 +407,8 @@ TEST_F(SyntaxNetComponentTest,
constexpr int kExpectedNumTransitions = kNumTokensInLongSentence * 2; constexpr int kExpectedNumTransitions = kNumTokensInLongSentence * 2;
for (int i = 0; i < kExpectedNumTransitions; ++i) { for (int i = 0; i < kExpectedNumTransitions; ++i) {
EXPECT_FALSE(test_parser->IsTerminal()); EXPECT_FALSE(test_parser->IsTerminal());
test_parser->AdvanceFromPrediction( EXPECT_TRUE(test_parser->AdvanceFromPrediction(
transition_matrix, kNumPossibleTransitions * kBeamSize * kBatchSize); transition_matrix, kBeamSize * kBatchSize, kNumPossibleTransitions));
} }
// At this point, the test parser should be terminal. // At this point, the test parser should be terminal.
...@@ -467,7 +492,7 @@ TEST_F(SyntaxNetComponentTest, ResetAllowsReductionInBatchSize) { ...@@ -467,7 +492,7 @@ TEST_F(SyntaxNetComponentTest, ResetAllowsReductionInBatchSize) {
for (int i = 0; i < kExpectedNumTransitions; ++i) { for (int i = 0; i < kExpectedNumTransitions; ++i) {
EXPECT_FALSE(parser_component->IsTerminal()); EXPECT_FALSE(parser_component->IsTerminal());
parser_component->AdvanceFromPrediction( parser_component->AdvanceFromPrediction(
transition_matrix, kNumPossibleTransitions * kBeamSize * kBatchSize); transition_matrix, kBeamSize * kBatchSize, kNumPossibleTransitions);
} }
// At this point, the test parser should be terminal. // At this point, the test parser should be terminal.
...@@ -553,7 +578,7 @@ TEST_F(SyntaxNetComponentTest, ResetAllowsIncreaseInBatchSize) { ...@@ -553,7 +578,7 @@ TEST_F(SyntaxNetComponentTest, ResetAllowsIncreaseInBatchSize) {
for (int i = 0; i < kExpectedNumTransitions; ++i) { for (int i = 0; i < kExpectedNumTransitions; ++i) {
EXPECT_FALSE(parser_component->IsTerminal()); EXPECT_FALSE(parser_component->IsTerminal());
parser_component->AdvanceFromPrediction( parser_component->AdvanceFromPrediction(
transition_matrix, kNumPossibleTransitions * kBeamSize * kBatchSize); transition_matrix, kBeamSize * kBatchSize, kNumPossibleTransitions);
} }
// At this point, the test parser should be terminal. // At this point, the test parser should be terminal.
...@@ -611,8 +636,8 @@ TEST_F(SyntaxNetComponentTest, ResetCausesBeamToReset) { ...@@ -611,8 +636,8 @@ TEST_F(SyntaxNetComponentTest, ResetCausesBeamToReset) {
// Transition the expected number of times. // Transition the expected number of times.
for (int i = 0; i < kExpectedNumTransitions; ++i) { for (int i = 0; i < kExpectedNumTransitions; ++i) {
EXPECT_FALSE(test_parser->IsTerminal()); EXPECT_FALSE(test_parser->IsTerminal());
test_parser->AdvanceFromPrediction(transition_matrix, EXPECT_TRUE(test_parser->AdvanceFromPrediction(transition_matrix, kBeamSize,
kNumPossibleTransitions * kBeamSize); kNumPossibleTransitions));
} }
// At this point, the test parser should be terminal. // At this point, the test parser should be terminal.
...@@ -823,10 +848,10 @@ TEST_F(SyntaxNetComponentTest, ExportsFixedFeatures) { ...@@ -823,10 +848,10 @@ TEST_F(SyntaxNetComponentTest, ExportsFixedFeatures) {
} }
// Advance twice, so that the underlying parser fills the beam. // Advance twice, so that the underlying parser fills the beam.
test_parser->AdvanceFromPrediction( EXPECT_TRUE(test_parser->AdvanceFromPrediction(
transition_matrix, kNumPossibleTransitions * kBeamSize * kBatchSize); transition_matrix, kBeamSize * kBatchSize, kNumPossibleTransitions));
test_parser->AdvanceFromPrediction( EXPECT_TRUE(test_parser->AdvanceFromPrediction(
transition_matrix, kNumPossibleTransitions * kBeamSize * kBatchSize); transition_matrix, kBeamSize * kBatchSize, kNumPossibleTransitions));
// Get and check the raw link features. // Get and check the raw link features.
vector<int32> indices; vector<int32> indices;
...@@ -907,10 +932,10 @@ TEST_F(SyntaxNetComponentTest, AdvancesAccordingToHighestWeightedInputOption) { ...@@ -907,10 +932,10 @@ TEST_F(SyntaxNetComponentTest, AdvancesAccordingToHighestWeightedInputOption) {
transition_matrix[kBatchOffset + 5] = 2 * kTransitionValue; transition_matrix[kBatchOffset + 5] = 2 * kTransitionValue;
// Advance twice, so that the underlying parser fills the beam. // Advance twice, so that the underlying parser fills the beam.
test_parser->AdvanceFromPrediction( EXPECT_TRUE(test_parser->AdvanceFromPrediction(
transition_matrix, kNumPossibleTransitions * kBeamSize * kBatchSize); transition_matrix, kBeamSize * kBatchSize, kNumPossibleTransitions));
test_parser->AdvanceFromPrediction( EXPECT_TRUE(test_parser->AdvanceFromPrediction(
transition_matrix, kNumPossibleTransitions * kBeamSize * kBatchSize); transition_matrix, kBeamSize * kBatchSize, kNumPossibleTransitions));
// Get and check the raw link features. // Get and check the raw link features.
vector<int32> indices; vector<int32> indices;
...@@ -1112,10 +1137,10 @@ TEST_F(SyntaxNetComponentTest, ExportsRawLinkFeatures) { ...@@ -1112,10 +1137,10 @@ TEST_F(SyntaxNetComponentTest, ExportsRawLinkFeatures) {
} }
// Advance twice, so that the underlying parser fills the beam. // Advance twice, so that the underlying parser fills the beam.
test_parser->AdvanceFromPrediction( EXPECT_TRUE(test_parser->AdvanceFromPrediction(
transition_matrix, kNumPossibleTransitions * kBeamSize * kBatchSize); transition_matrix, kBeamSize * kBatchSize, kNumPossibleTransitions));
test_parser->AdvanceFromPrediction( EXPECT_TRUE(test_parser->AdvanceFromPrediction(
transition_matrix, kNumPossibleTransitions * kBeamSize * kBatchSize); transition_matrix, kBeamSize * kBatchSize, kNumPossibleTransitions));
// Get and check the raw link features. // Get and check the raw link features.
constexpr int kNumLinkFeatures = 2; constexpr int kNumLinkFeatures = 2;
...@@ -1269,5 +1294,21 @@ TEST_F(SyntaxNetComponentTest, TracingOutputsFeatureNames) { ...@@ -1269,5 +1294,21 @@ TEST_F(SyntaxNetComponentTest, TracingOutputsFeatureNames) {
EXPECT_EQ(link_features.at(1).feature_name(), "stack(1).focus"); EXPECT_EQ(link_features.at(1).feature_name(), "stack(1).focus");
} }
TEST_F(SyntaxNetComponentTest, BulkEmbedFixedFeaturesIsNotSupported) {
// Create an empty input batch and beam vector to initialize the parser.
Sentence sentence_0;
// TODO(googleuser): Wrap this in a lint-friendly helper function.
TextFormat::ParseFromString(kSentence0, &sentence_0);
string sentence_0_str;
sentence_0.SerializeToString(&sentence_0_str);
constexpr int kBeamSize = 1;
auto test_parser = CreateParserWithBeamSize(kBeamSize, {}, {sentence_0_str});
EXPECT_TRUE(test_parser->IsReady());
EXPECT_DEATH(test_parser->BulkEmbedFixedFeatures(0, 0, 0, {nullptr}, nullptr),
"Method not supported");
}
} // namespace dragnn } // namespace dragnn
} // namespace syntaxnet } // namespace syntaxnet
...@@ -13,8 +13,8 @@ ...@@ -13,8 +13,8 @@
// limitations under the License. // limitations under the License.
// ============================================================================= // =============================================================================
#ifndef NLP_SAFT_OPENSOURCE_DRAGNN_COMPONENTS_SYNTAXNET_SYNTAXNET_LINK_FEATURE_EXTRACTOR_H_ #ifndef DRAGNN_COMPONENTS_SYNTAXNET_SYNTAXNET_LINK_FEATURE_EXTRACTOR_H_
#define NLP_SAFT_OPENSOURCE_DRAGNN_COMPONENTS_SYNTAXNET_SYNTAXNET_LINK_FEATURE_EXTRACTOR_H_ #define DRAGNN_COMPONENTS_SYNTAXNET_SYNTAXNET_LINK_FEATURE_EXTRACTOR_H_
#include <string> #include <string>
#include <vector> #include <vector>
...@@ -29,12 +29,8 @@ namespace syntaxnet { ...@@ -29,12 +29,8 @@ namespace syntaxnet {
namespace dragnn { namespace dragnn {
// Provides feature extraction for linked features in the // Provides feature extraction for linked features in the
// WrapperParserComponent. This re-ues the EmbeddingFeatureExtractor // WrapperParserComponent. This re-uses the EmbeddingFeatureExtractor
// architecture to get another set of feature extractors. Note that we should // architecture to get another set of feature extractors.
// ignore predicate maps here, and we don't care about the vocabulary size
// because all the feature values will be used for translation, but this means
// we can configure the extractor from the GCL using the standard
// neurosis-lib.wf syntax.
// //
// Because it uses a different prefix, it can be executed in the same wf.stage // Because it uses a different prefix, it can be executed in the same wf.stage
// as the regular fixed extractor. // as the regular fixed extractor.
...@@ -67,4 +63,4 @@ class SyntaxNetLinkFeatureExtractor : public ParserEmbeddingFeatureExtractor { ...@@ -67,4 +63,4 @@ class SyntaxNetLinkFeatureExtractor : public ParserEmbeddingFeatureExtractor {
} // namespace dragnn } // namespace dragnn
} // namespace syntaxnet } // namespace syntaxnet
#endif // NLP_SAFT_OPENSOURCE_DRAGNN_COMPONENTS_SYNTAXNET_SYNTAXNET_LINK_FEATURE_EXTRACTOR_H_ #endif // DRAGNN_COMPONENTS_SYNTAXNET_SYNTAXNET_LINK_FEATURE_EXTRACTOR_H_
...@@ -34,7 +34,7 @@ class ExportSpecTest : public ::testing::Test { ...@@ -34,7 +34,7 @@ class ExportSpecTest : public ::testing::Test {
TEST_F(ExportSpecTest, WritesChannelSpec) { TEST_F(ExportSpecTest, WritesChannelSpec) {
TaskContext context; TaskContext context;
context.SetParameter("neurosis_feature_syntax_version", "2");
context.SetParameter("link_features", "input.focus;stack.focus"); context.SetParameter("link_features", "input.focus;stack.focus");
context.SetParameter("link_embedding_names", "tagger;parser"); context.SetParameter("link_embedding_names", "tagger;parser");
context.SetParameter("link_predicate_maps", "none;none"); context.SetParameter("link_predicate_maps", "none;none");
......
...@@ -23,7 +23,9 @@ namespace dragnn { ...@@ -23,7 +23,9 @@ namespace dragnn {
SyntaxNetTransitionState::SyntaxNetTransitionState( SyntaxNetTransitionState::SyntaxNetTransitionState(
std::unique_ptr<ParserState> parser_state, SyntaxNetSentence *sentence) std::unique_ptr<ParserState> parser_state, SyntaxNetSentence *sentence)
: parser_state_(std::move(parser_state)), sentence_(sentence) { : parser_state_(std::move(parser_state)),
sentence_(sentence),
is_gold_(false) {
score_ = 0; score_ = 0;
current_beam_index_ = -1; current_beam_index_ = -1;
parent_beam_index_ = 0; parent_beam_index_ = 0;
...@@ -60,21 +62,25 @@ std::unique_ptr<SyntaxNetTransitionState> SyntaxNetTransitionState::Clone() ...@@ -60,21 +62,25 @@ std::unique_ptr<SyntaxNetTransitionState> SyntaxNetTransitionState::Clone()
return new_state; return new_state;
} }
const int SyntaxNetTransitionState::ParentBeamIndex() const { int SyntaxNetTransitionState::ParentBeamIndex() const {
return parent_beam_index_; return parent_beam_index_;
} }
const int SyntaxNetTransitionState::GetBeamIndex() const { int SyntaxNetTransitionState::GetBeamIndex() const {
return current_beam_index_; return current_beam_index_;
} }
void SyntaxNetTransitionState::SetBeamIndex(const int index) { bool SyntaxNetTransitionState::IsGold() const { return is_gold_; }
void SyntaxNetTransitionState::SetGold(bool is_gold) { is_gold_ = is_gold; }
void SyntaxNetTransitionState::SetBeamIndex(int index) {
current_beam_index_ = index; current_beam_index_ = index;
} }
const float SyntaxNetTransitionState::GetScore() const { return score_; } float SyntaxNetTransitionState::GetScore() const { return score_; }
void SyntaxNetTransitionState::SetScore(const float score) { score_ = score; } void SyntaxNetTransitionState::SetScore(float score) { score_ = score; }
string SyntaxNetTransitionState::HTMLRepresentation() const { string SyntaxNetTransitionState::HTMLRepresentation() const {
// Crude HTML string showing the stack and the word on the input. // Crude HTML string showing the stack and the word on the input.
......
...@@ -13,8 +13,8 @@ ...@@ -13,8 +13,8 @@
// limitations under the License. // limitations under the License.
// ============================================================================= // =============================================================================
#ifndef NLP_SAFT_OPENSOURCE_DRAGNN_COMPONENTS_SYNTAXNET_SYNTAXNET_TRANSITION_STATE_H_ #ifndef DRAGNN_COMPONENTS_SYNTAXNET_SYNTAXNET_TRANSITION_STATE_H_
#define NLP_SAFT_OPENSOURCE_DRAGNN_COMPONENTS_SYNTAXNET_SYNTAXNET_TRANSITION_STATE_H_ #define DRAGNN_COMPONENTS_SYNTAXNET_SYNTAXNET_TRANSITION_STATE_H_
#include <vector> #include <vector>
...@@ -31,11 +31,11 @@ namespace dragnn { ...@@ -31,11 +31,11 @@ namespace dragnn {
class SyntaxNetTransitionState class SyntaxNetTransitionState
: public CloneableTransitionState<SyntaxNetTransitionState> { : public CloneableTransitionState<SyntaxNetTransitionState> {
public: public:
// Create a SyntaxNetTransitionState to wrap this nlp_saft::ParserState. // Creates a SyntaxNetTransitionState to wrap this ParserState.
SyntaxNetTransitionState(std::unique_ptr<ParserState> parser_state, SyntaxNetTransitionState(std::unique_ptr<ParserState> parser_state,
SyntaxNetSentence *sentence); SyntaxNetSentence *sentence);
// Initialize this TransitionState from a previous TransitionState. The // Initializes this TransitionState from a previous TransitionState. The
// ParentBeamIndex is the location of that previous TransitionState in the // ParentBeamIndex is the location of that previous TransitionState in the
// provided beam. // provided beam.
void Init(const TransitionState &parent) override; void Init(const TransitionState &parent) override;
...@@ -43,21 +43,27 @@ class SyntaxNetTransitionState ...@@ -43,21 +43,27 @@ class SyntaxNetTransitionState
// Produces a new state with the same backing data as this state. // Produces a new state with the same backing data as this state.
std::unique_ptr<SyntaxNetTransitionState> Clone() const override; std::unique_ptr<SyntaxNetTransitionState> Clone() const override;
// Return the beam index of the state passed into the initializer of this // Returns the beam index of the state passed into the initializer of this
// TransitionState. // TransitionState.
const int ParentBeamIndex() const override; int ParentBeamIndex() const override;
// Get the current beam index for this state. // Gets the current beam index for this state.
const int GetBeamIndex() const override; int GetBeamIndex() const override;
// Set the current beam index for this state. // Sets the current beam index for this state.
void SetBeamIndex(const int index) override; void SetBeamIndex(int index) override;
// Get the score associated with this transition state. // Gets the score associated with this transition state.
const float GetScore() const override; float GetScore() const override;
// Set the score associated with this transition state. // Sets the score associated with this transition state.
void SetScore(const float score) override; void SetScore(float score) override;
// Gets the state's gold-ness (if it is on or consistent with the oracle path)
bool IsGold() const override;
// Sets the gold-ness of this state.
void SetGold(bool is_gold) override;
// Depicts this state as an HTML-language string. // Depicts this state as an HTML-language string.
string HTMLRepresentation() const override; string HTMLRepresentation() const override;
...@@ -108,7 +114,7 @@ class SyntaxNetTransitionState ...@@ -108,7 +114,7 @@ class SyntaxNetTransitionState
parent_for_token_.insert(parent_for_token_.begin() + token, parent); parent_for_token_.insert(parent_for_token_.begin() + token, parent);
} }
// Accessor for the underlying nlp_saft::ParserState. // Accessor for the underlying ParserState.
ParserState *parser_state() { return parser_state_.get(); } ParserState *parser_state() { return parser_state_.get(); }
// Accessor for the underlying sentence object. // Accessor for the underlying sentence object.
...@@ -151,9 +157,12 @@ class SyntaxNetTransitionState ...@@ -151,9 +157,12 @@ class SyntaxNetTransitionState
// Trace of the history to produce this state. // Trace of the history to produce this state.
std::unique_ptr<ComponentTrace> trace_; std::unique_ptr<ComponentTrace> trace_;
// True if this state is gold.
bool is_gold_;
}; };
} // namespace dragnn } // namespace dragnn
} // namespace syntaxnet } // namespace syntaxnet
#endif // NLP_SAFT_OPENSOURCE_DRAGNN_COMPONENTS_SYNTAXNET_SYNTAXNET_TRANSITION_STATE_H_ #endif // DRAGNN_COMPONENTS_SYNTAXNET_SYNTAXNET_TRANSITION_STATE_H_
...@@ -134,6 +134,22 @@ TEST_F(SyntaxNetTransitionStateTest, CanSetAndGetScore) { ...@@ -134,6 +134,22 @@ TEST_F(SyntaxNetTransitionStateTest, CanSetAndGetScore) {
EXPECT_EQ(test_state->GetScore(), kNewScore); EXPECT_EQ(test_state->GetScore(), kNewScore);
} }
// Validates the consistency of the goldness setter and getter.
TEST_F(SyntaxNetTransitionStateTest, CanSetAndGetGold) {
// Create and initialize a test state.
MockTransitionState mock_state;
auto test_state = CreateState();
test_state->Init(mock_state);
constexpr bool kOldGold = true;
test_state->SetGold(kOldGold);
EXPECT_EQ(test_state->IsGold(), kOldGold);
constexpr bool kNewGold = false;
test_state->SetGold(kNewGold);
EXPECT_EQ(test_state->IsGold(), kNewGold);
}
// This test ensures that the initializing state's current index is saved // This test ensures that the initializing state's current index is saved
// as the parent beam index of the state being initialized. // as the parent beam index of the state being initialized.
TEST_F(SyntaxNetTransitionStateTest, ReportsParentBeamIndex) { TEST_F(SyntaxNetTransitionStateTest, ReportsParentBeamIndex) {
......
...@@ -13,8 +13,8 @@ ...@@ -13,8 +13,8 @@
// limitations under the License. // limitations under the License.
// ============================================================================= // =============================================================================
#ifndef NLP_SAFT_OPENSOURCE_DRAGNN_COMPONENTS_UTIL_BULK_FEATURE_EXTRACTOR_H_ #ifndef DRAGNN_COMPONENTS_UTIL_BULK_FEATURE_EXTRACTOR_H_
#define NLP_SAFT_OPENSOURCE_DRAGNN_COMPONENTS_UTIL_BULK_FEATURE_EXTRACTOR_H_ #define DRAGNN_COMPONENTS_UTIL_BULK_FEATURE_EXTRACTOR_H_
#include <functional> #include <functional>
#include <utility> #include <utility>
...@@ -107,4 +107,4 @@ class BulkFeatureExtractor { ...@@ -107,4 +107,4 @@ class BulkFeatureExtractor {
} // namespace dragnn } // namespace dragnn
} // namespace syntaxnet } // namespace syntaxnet
#endif // NLP_SAFT_OPENSOURCE_DRAGNN_COMPONENTS_UTIL_BULK_FEATURE_EXTRACTOR_H_ #endif // DRAGNN_COMPONENTS_UTIL_BULK_FEATURE_EXTRACTOR_H_
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment