ModelZoo / ResNet50_tensorflow · Commits

Commit 4364390a, authored Nov 13, 2017 by Ivan Bogatyy, committed by calberti on Nov 13, 2017
Parent: 638fd759

Release DRAGNN bulk networks (#2785)

* Release DRAGNN bulk networks

Showing 20 changed files with 396 additions and 128 deletions (+396 -128)
Files changed:

  research/syntaxnet/Dockerfile                                                             +7   -8
  research/syntaxnet/README.md                                                              +35  -7
  research/syntaxnet/docker-devel/Dockerfile-test                                           +11  -0
  research/syntaxnet/docker-devel/Dockerfile-test-base                                      +91  -0
  research/syntaxnet/docker-devel/Dockerfile.min                                            +7   -7
  research/syntaxnet/docker-devel/README.txt                                                +2   -2
  research/syntaxnet/dragnn/__init__.py                                                     +0   -0
  research/syntaxnet/dragnn/components/stateless/BUILD                                      +0   -1
  research/syntaxnet/dragnn/components/stateless/stateless_component.cc                     +14  -12
  research/syntaxnet/dragnn/components/stateless/stateless_component_test.cc                +37  -0
  research/syntaxnet/dragnn/components/syntaxnet/syntaxnet_component.cc                     +49  -24
  research/syntaxnet/dragnn/components/syntaxnet/syntaxnet_component.h                      +16  -8
  research/syntaxnet/dragnn/components/syntaxnet/syntaxnet_component_test.cc                +65  -24
  research/syntaxnet/dragnn/components/syntaxnet/syntaxnet_link_feature_extractor.h         +5   -9
  research/syntaxnet/dragnn/components/syntaxnet/syntaxnet_link_feature_extractor_test.cc   +1   -1
  research/syntaxnet/dragnn/components/syntaxnet/syntaxnet_transition_state.cc              +12  -6
  research/syntaxnet/dragnn/components/syntaxnet/syntaxnet_transition_state.h               +25  -16
  research/syntaxnet/dragnn/components/syntaxnet/syntaxnet_transition_state_test.cc         +16  -0
  research/syntaxnet/dragnn/components/util/bulk_feature_extractor.h                        +3   -3
  research/syntaxnet/dragnn/config_builder/__init__.py                                      +0   -0
research/syntaxnet/Dockerfile

-# Java baseimage, for Bazel.
-FROM openjdk:8
+FROM ubuntu:16.10
 ENV SYNTAXNETDIR=/opt/tensorflow PATH=$PATH:/root/bin

@@ -21,13 +20,15 @@ RUN mkdir -p $SYNTAXNETDIR \
         libopenblas-dev \
         libpng-dev \
         libxft-dev \
-        patch \
+        openjdk-8-jdk \
         python-dev \
         python-mock \
         python-pip \
         python2.7 \
         swig \
+        unzip \
         vim \
+        wget \
         zlib1g-dev \
     && apt-get clean \
     && (rm -f /var/cache/apt/archives/*.deb \

@@ -55,7 +56,7 @@ RUN python -m pip install \
       --py --sys-prefix widgetsnbextension \
     && rm -rf /root/.cache/pip /tmp/pip*

-# Installs the latest version of Bazel.
+# Installs Bazel.
 RUN wget --quiet https://github.com/bazelbuild/bazel/releases/download/0.5.4/bazel-0.5.4-installer-linux-x86_64.sh \
     && chmod +x bazel-0.5.4-installer-linux-x86_64.sh \
     && ./bazel-0.5.4-installer-linux-x86_64.sh \

@@ -65,13 +66,11 @@ COPY WORKSPACE $SYNTAXNETDIR/syntaxnet/WORKSPACE
 COPY tools/bazel.rc $SYNTAXNETDIR/syntaxnet/tools/bazel.rc
 COPY tensorflow $SYNTAXNETDIR/syntaxnet/tensorflow

-# Workaround solving the PYTHON_BIN_PATH not found problem
-ENV PYTHON_BIN_PATH=/usr/bin/python
 # Compile common TensorFlow targets, which don't depend on DRAGNN / SyntaxNet
 # source. This makes it more convenient to re-compile DRAGNN / SyntaxNet for
 # development (though not as convenient as the docker-devel scripts).
 RUN cd $SYNTAXNETDIR/syntaxnet/tensorflow \
-    && ./configure CPU \
+    && tensorflow/tools/ci_build/builds/configured CPU \
     && cd $SYNTAXNETDIR/syntaxnet \
     && bazel build -c opt @org_tensorflow//tensorflow:tensorflow_py

@@ -92,4 +91,4 @@ EXPOSE 8888
 COPY examples $SYNTAXNETDIR/syntaxnet/examples

 # Todo: Move this earlier in the file (don't want to invalidate caches for now).
-CMD /bin/bash -c "bazel-bin/dragnn/tools/oss_notebook_launcher notebook --debug --notebook-dir=/opt/tensorflow/syntaxnet/examples --allow-root"
+CMD /bin/bash -c "bazel-bin/dragnn/tools/oss_notebook_launcher notebook --debug --notebook-dir=/opt/tensorflow/syntaxnet/examples"
research/syntaxnet/README.md

@@ -23,8 +23,8 @@ This repository is largely divided into two sub-packages:
   [documentation](g3doc/DRAGNN.md),
   [paper](https://arxiv.org/pdf/1703.04474.pdf)** implements Dynamic Recurrent
   Acyclic Graphical Neural Networks (DRAGNN), a framework for building
-  multi-task, fully dynamically constructed computation graphs. Practically, we
-  use DRAGNN to extend our prior work from [Andor et al.
+  multi-task, fully dynamically constructed computation graphs. Practically,
+  we use DRAGNN to extend our prior work from [Andor et al.
   (2016)](http://arxiv.org/abs/1603.06042) with end-to-end, deep recurrent
   models and to provide a much easier to use interface to SyntaxNet. *DRAGNN
   is designed first and foremost as a Python library, and therefore much

@@ -54,20 +54,47 @@ There are three ways to use SyntaxNet:

 ### Docker installation

+_This process takes ~10 minutes._
+
 The simplest way to get started with DRAGNN is by loading our Docker container.
 [Here](g3doc/CLOUD.md) is a tutorial for running the DRAGNN container on
 [GCP](https://cloud.google.com) (just as applicable to your own computer).

+### Ubuntu 16.10+ binary installation
+
+_This process takes ~5 minutes, but is only compatible with Linux using GNU libc
+3.4.22 and above (e.g. Ubuntu 16.10)._
+
+Binary wheel packages are provided for TensorFlow and SyntaxNet. If you do not
+need to write new binary TensorFlow ops, these should suffice.
+
+*   `apt-get install -y graphviz libgraphviz-dev libopenblas-base libpng16-16
+    libxft2 python-pip python-mock`
+*   `pip install pygraphviz
+    --install-option="--include-path=/usr/include/graphviz"
+    --install-option="--library-path=/usr/lib/graphviz/"`
+*   `pip install 'ipython<6.0' protobuf numpy scipy jupyter
+    syntaxnet-with-tensorflow`
+*   `python -m jupyter_core.command nbextension enable --py --sys-prefix
+    widgetsnbextension`
+
+You can test that binary modules can be successfully imported by running,
+
+*   `python -c 'import dragnn.python.load_dragnn_cc_impl,
+    syntaxnet.load_parser_ops'`
+
 ### Manual installation

+_This process takes 1-2 hours._
+
 Running and training SyntaxNet/DRAGNN models requires building this package from
 source. You'll need to install:

 *   python 2.7:
     *   Python 3 support is not available yet
-*   bazel:
+*   bazel 0.5.4:
     *   Follow the instructions [here](http://bazel.build/docs/install.html)
-    *   Alternately, Download bazel <.deb> from
+    *   Alternately, Download bazel 0.5.4 <.deb> from
         [https://github.com/bazelbuild/bazel/releases](https://github.com/bazelbuild/bazel/releases)
         for your system configuration.
     *   Install it using the command: sudo dpkg -i <.deb file>

@@ -103,9 +130,12 @@ following commands:
   bazel test --linkopt=-headerpad_max_install_names \
     dragnn/... syntaxnet/... util/utf8/...
 ```

 Bazel should complete reporting all tests passed.

-Now you can install the SyntaxNet and DRAGNN Python modules with the following commands:
+Now you can install the SyntaxNet and DRAGNN Python modules with the following
+commands:

 ```shell
 mkdir /tmp/syntaxnet_pkg
 bazel-bin/dragnn/tools/build_pip_package --output-dir=/tmp/syntaxnet_pkg

@@ -116,8 +146,6 @@ Now you can install the SyntaxNet and DRAGNN Python modules with the following c
 To build SyntaxNet with GPU support please refer to the instructions in
 [issues/248](https://github.com/tensorflow/models/issues/248).

 **Note:** If you are running Docker on OSX, make sure that you have enough
 memory allocated for your Docker VM.
research/syntaxnet/docker-devel/Dockerfile-test  (new file, mode 100644)
FROM dragnn-oss-test-base:latest
RUN rm -rf \
$SYNTAXNETDIR/syntaxnet/dragnn \
$SYNTAXNETDIR/syntaxnet/syntaxnet \
$SYNTAXNETDIR/syntaxnet/third_party \
$SYNTAXNETDIR/syntaxnet/util/utf8
COPY dragnn $SYNTAXNETDIR/syntaxnet/dragnn
COPY syntaxnet $SYNTAXNETDIR/syntaxnet/syntaxnet
COPY third_party $SYNTAXNETDIR/syntaxnet/third_party
COPY util/utf8 $SYNTAXNETDIR/syntaxnet/util/utf8
research/syntaxnet/docker-devel/Dockerfile-test-base  (new file, mode 100644)
FROM ubuntu:16.10
ENV SYNTAXNETDIR=/opt/tensorflow PATH=$PATH:/root/bin
# Install system packages. This doesn't include everything the TensorFlow
# dockerfile specifies, so if anything goes awry, maybe install more packages
# from there. Also, running apt-get clean before further commands will make the
# Docker images smaller.
RUN mkdir -p $SYNTAXNETDIR \
&& cd $SYNTAXNETDIR \
&& apt-get update \
&& apt-get install -y \
file \
git \
graphviz \
libcurl3-dev \
libfreetype6-dev \
libgraphviz-dev \
liblapack-dev \
libopenblas-dev \
libpng-dev \
libxft-dev \
openjdk-8-jdk \
python-dev \
python-mock \
python-pip \
python2.7 \
swig \
unzip \
vim \
wget \
zlib1g-dev \
&& apt-get clean \
&& (rm -f /var/cache/apt/archives/*.deb \
/var/cache/apt/archives/partial/*.deb /var/cache/apt/*.bin || true)
# Install common Python dependencies. Similar to above, remove caches
# afterwards to help keep Docker images smaller.
RUN pip install --ignore-installed pip \
&& python -m pip install numpy \
&& rm -rf /root/.cache/pip /tmp/pip*
RUN python -m pip install \
asciitree \
ipykernel \
jupyter \
matplotlib \
pandas \
protobuf \
scipy \
sklearn \
&& python -m ipykernel.kernelspec \
&& python -m pip install pygraphviz \
--install-option="--include-path=/usr/include/graphviz" \
--install-option="--library-path=/usr/lib/graphviz/" \
&& python -m jupyter_core.command nbextension enable \
--py --sys-prefix widgetsnbextension \
&& rm -rf /root/.cache/pip /tmp/pip*
# Installs Bazel.
RUN wget --quiet https://github.com/bazelbuild/bazel/releases/download/0.5.3/bazel-0.5.3-installer-linux-x86_64.sh \
&& chmod +x bazel-0.5.3-installer-linux-x86_64.sh \
&& JAVA_HOME=/usr/lib/jvm/java-8-openjdk-amd64/ ./bazel-0.5.3-installer-linux-x86_64.sh \
&& rm ./bazel-0.5.3-installer-linux-x86_64.sh
COPY WORKSPACE $SYNTAXNETDIR/syntaxnet/WORKSPACE
COPY tools/bazel.rc $SYNTAXNETDIR/syntaxnet/tools/bazel.rc
# Compile common TensorFlow targets, which don't depend on DRAGNN / SyntaxNet
# source. This makes it more convenient to re-compile DRAGNN / SyntaxNet for
# development (though not as convenient as the docker-devel scripts).
RUN cd $SYNTAXNETDIR/syntaxnet \
&& git clone --branch r1.3 --recurse-submodules https://github.com/tensorflow/tensorflow \
&& cd tensorflow \
# This line removes a bad archive target which causes Tensorflow install
# to fail.
&& sed -i '\@https://github.com/google/protobuf/archive/0b059a3d8a8f8aa40dde7bea55edca4ec5dfea66.tar.gz@d' tensorflow/workspace.bzl \
&& tensorflow/tools/ci_build/builds/configured CPU \\
&& cd $SYNTAXNETDIR/syntaxnet \
&& bazel build -c opt @org_tensorflow//tensorflow:tensorflow_py
# Just copy the code and run tests. The build and test flags differ enough that
# doing a normal build of TensorFlow targets doesn't save much test time.
WORKDIR $SYNTAXNETDIR/syntaxnet
COPY dragnn $SYNTAXNETDIR/syntaxnet/dragnn
COPY syntaxnet $SYNTAXNETDIR/syntaxnet/syntaxnet
COPY third_party $SYNTAXNETDIR/syntaxnet/third_party
COPY util/utf8 $SYNTAXNETDIR/syntaxnet/util/utf8
# Doesn't matter if the tests pass or not, since we're going to re-copy over the
# code.
RUN bazel test -c opt ... || true
research/syntaxnet/docker-devel/Dockerfile.min

 # You need to build wheels before building this image. Please consult
 # docker-devel/README.txt.
-# This is the base of the openjdk image.
 #
 # It might be more efficient to use a minimal distribution, like Alpine. But
 # the upside of this being popular is that people might already have it.
-FROM buildpack-deps:jessie-curl
+FROM ubuntu:16.10
 ENV SYNTAXNETDIR=/opt/tensorflow PATH=$PATH:/root/bin

@@ -19,7 +17,7 @@ RUN apt-get update \
     libgraphviz-dev \
     liblapack3 \
     libopenblas-base \
-    libpng12-0 \
+    libpng16-16 \
     libxft2 \
     python-dev \
     python-mock \

@@ -48,11 +46,13 @@ RUN python -m pip install \
     && python -m pip install pygraphviz \
       --install-option="--include-path=/usr/include/graphviz" \
       --install-option="--library-path=/usr/lib/graphviz/" \
+    && python -m jupyter_core.command nbextension enable \
+      --py --sys-prefix widgetsnbextension \
     && rm -rf /root/.cache/pip /tmp/pip*

-COPY syntaxnet_with_tensorflow-0.2-cp27-none-linux_x86_64.whl $SYNTAXNETDIR/
+COPY syntaxnet_with_tensorflow-0.2-cp27-cp27mu-linux_x86_64.whl $SYNTAXNETDIR/
 RUN python -m pip install \
-    $SYNTAXNETDIR/syntaxnet_with_tensorflow-0.2-cp27-none-linux_x86_64.whl \
+    $SYNTAXNETDIR/syntaxnet_with_tensorflow-0.2-cp27-cp27mu-linux_x86_64.whl \
     && rm -rf /root/.cache/pip /tmp/pip*

 # This makes the IP exposed actually "*"; we'll do host restrictions by passing

@@ -63,4 +63,4 @@ EXPOSE 8888
 # This does not need to be compiled, only copied.
 COPY examples $SYNTAXNETDIR/syntaxnet/examples

 # For some reason, this works if we run it in a bash shell :/ :/ :/
-CMD /bin/bash -c "python -m jupyter_core.command notebook --debug --notebook-dir=/opt/tensorflow/syntaxnet/examples"
+CMD /bin/bash -c "python -m jupyter_core.command notebook --debug --notebook-dir=/opt/tensorflow/syntaxnet/examples --allow-root"
research/syntaxnet/docker-devel/README.txt

@@ -43,11 +43,11 @@ Step 3: Building the development image
 First, ensure you have the file

-  syntaxnet_with_tensorflow-0.2-cp27-none-linux_x86_64.whl
+  syntaxnet_with_tensorflow-0.2-cp27-cp27mu-linux_x86_64.whl

 in your working directory, from step 2. Then run,

   docker build -t dragnn-oss:latest-minimal -f docker-devel/Dockerfile.min .

 If the filename changes (e.g. you are on a different architecture), just update
 Dockerfile.min.
research/syntaxnet/dragnn/__init__.py  (new empty file, mode 100644)
research/syntaxnet/dragnn/components/stateless/BUILD

@@ -10,7 +10,6 @@ cc_library(
         "//dragnn/core:component_registry",
         "//dragnn/core/interfaces:component",
         "//dragnn/core/interfaces:transition_state",
-        "//dragnn/io:sentence_input_batch",
         "//dragnn/protos:data_proto",
         "//syntaxnet:base",
     ],
research/syntaxnet/dragnn/components/stateless/stateless_component.cc

@@ -16,7 +16,6 @@
 #include "dragnn/core/component_registry.h"
 #include "dragnn/core/interfaces/component.h"
 #include "dragnn/core/interfaces/transition_state.h"
-#include "dragnn/io/sentence_input_batch.h"
 #include "dragnn/protos/data.pb.h"
 #include "syntaxnet/base.h"

@@ -25,7 +24,8 @@ namespace dragnn {
 namespace {

 // A component that does not create its own transition states; instead, it
-// simply forwards the states of the previous component. Does not support all
+// simply forwards the states of the previous component. Requires that some
+// previous component has converted the input batch. Does not support all
 // methods. Intended for "compute-only" bulk components that only use linked
 // features, which use only a small subset of DRAGNN functionality.
 class StatelessComponent : public Component {

@@ -38,8 +38,7 @@ class StatelessComponent : public Component {
   void InitializeData(
       const std::vector<std::vector<const TransitionState *>> &parent_states,
       int max_beam_size, InputBatchCache *input_data) override {
-    // Must use SentenceInputBatch to match SyntaxNetComponent.
-    batch_size_ = input_data->GetAs<SentenceInputBatch>()->data()->size();
+    batch_size_ = input_data->Size();
     beam_size_ = max_beam_size;
     parent_states_ = parent_states;

@@ -84,31 +83,34 @@ class StatelessComponent : public Component {
     LOG(FATAL) << "[" << name_ << "] Method not supported";
     return nullptr;
   }

-  void AdvanceFromPrediction(const float transition_matrix[],
-                             int matrix_length) override {
-    LOG(FATAL) << "[" << name_ << "] Method not supported";
+  bool AdvanceFromPrediction(const float *transition_matrix, int num_items,
+                             int num_actions) override {
+    LOG(FATAL) << "[" << name_ << "] AdvanceFromPrediction not supported";
   }

   void AdvanceFromOracle() override {
-    LOG(FATAL) << "[" << name_ << "] Method not supported";
+    LOG(FATAL) << "[" << name_ << "] AdvanceFromOracle not supported";
   }

   std::vector<std::vector<int>> GetOracleLabels() const override {
     LOG(FATAL) << "[" << name_ << "] Method not supported";
-    return {};
   }

   int GetFixedFeatures(std::function<int32 *(int)> allocate_indices,
                        std::function<int64 *(int)> allocate_ids,
                        std::function<float *(int)> allocate_weights,
                        int channel_id) const override {
     LOG(FATAL) << "[" << name_ << "] Method not supported";
-    return 0;
   }

   int BulkGetFixedFeatures(const BulkFeatureExtractor &extractor) override {
     LOG(FATAL) << "[" << name_ << "] Method not supported";
-    return 0;
   }

+  void BulkEmbedFixedFeatures(
+      int batch_size_padding, int num_steps_padding, int output_array_size,
+      const vector<const float *> &per_channel_embeddings,
+      float *embedding_output) override {
+    LOG(FATAL) << "[" << name_ << "] Method not supported";
+  }
+
   std::vector<LinkFeatures> GetRawLinkFeatures(int channel_id) const override {
     LOG(FATAL) << "[" << name_ << "] Method not supported";
-    return {};
   }

   void AddTranslatedLinkFeaturesToTrace(
       const std::vector<LinkFeatures> &features, int channel_id) override {
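The signature change above is the core of the bulk-network API: AdvanceFromPrediction now takes a plain float pointer plus explicit item and action counts, and reports failure through its bool result instead of returning void. As a rough illustration only (not code from this commit; the helper name and pipeline structure are assumptions), a caller would now check the result:

  // Sketch: advancing a pipeline of components with the new API and
  // propagating failure to the caller instead of crashing inside a component.
  bool AdvanceAll(const std::vector<syntaxnet::dragnn::Component *> &components,
                  const float *scores, int num_items, int num_actions) {
    for (syntaxnet::dragnn::Component *component : components) {
      if (!component->AdvanceFromPrediction(scores, num_items, num_actions)) {
        return false;  // e.g. the score matrix contained NaNs
      }
    }
    return true;
  }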
research/syntaxnet/dragnn/components/stateless/stateless_component_test.cc

@@ -18,6 +18,7 @@
 #include "dragnn/core/test/generic.h"
 #include "dragnn/core/test/mock_transition_state.h"
 #include "dragnn/io/sentence_input_batch.h"
+#include "dragnn/protos/data.pb.h"
 #include "syntaxnet/base.h"
 #include "syntaxnet/sentence.pb.h"
 #include "tensorflow/core/lib/core/errors.h"

@@ -119,6 +120,10 @@ class StatelessComponentTest : public ::testing::Test {
     CHECK(TextFormat::ParseFromString(kMasterSpec, &master_spec));
     data_.reset(new InputBatchCache(data));

+    // The stateless component does not use any particular input batch type, and
+    // relies on the preceding components to convert the input batch.
+    data_->GetAs<SentenceInputBatch>();
+
     // Create a parser component with the specified beam size.
     std::unique_ptr<Component> parser_component(
         Component::Create("StatelessComponent"));

@@ -167,5 +172,37 @@ TEST_F(StatelessComponentTest, ForwardsTransitionStates) {
   EXPECT_EQ(parent_states, forwarded_states);
 }

+TEST_F(StatelessComponentTest, UnimplementedMethodsDie) {
+  MockTransitionState mock_state_1, mock_state_2, mock_state_3;
+  const std::vector<std::vector<const TransitionState *>> parent_states;
+  std::vector<string> data;
+  for (const string &textproto : {kSentence0, kSentence1, kLongSentence}) {
+    Sentence sentence;
+    CHECK(TextFormat::ParseFromString(textproto, &sentence));
+    data.emplace_back();
+    CHECK(sentence.SerializeToString(&data.back()));
+  }
+  const int kBeamSize = 2;
+  auto test_parser = CreateParser(kBeamSize, parent_states, data);
+  EXPECT_TRUE(test_parser->IsReady());
+  EXPECT_DEATH(test_parser->AdvanceFromPrediction({}, 0, 0),
+               "AdvanceFromPrediction not supported");
+  EXPECT_DEATH(test_parser->AdvanceFromOracle(),
+               "AdvanceFromOracle not supported");
+  EXPECT_DEATH(test_parser->GetOracleLabels(), "Method not supported");
+  EXPECT_DEATH(test_parser->GetFixedFeatures(nullptr, nullptr, nullptr, 0),
+               "Method not supported");
+  BulkFeatureExtractor extractor(nullptr, nullptr, nullptr);
+  EXPECT_DEATH(
+      test_parser->BulkEmbedFixedFeatures(0, 0, 0, {nullptr}, nullptr),
+      "Method not supported");
+  EXPECT_DEATH(test_parser->BulkGetFixedFeatures(extractor),
+               "Method not supported");
+  EXPECT_DEATH(test_parser->GetRawLinkFeatures(0), "Method not supported");
+  EXPECT_DEATH(test_parser->AddTranslatedLinkFeaturesToTrace({}, 0),
+               "Method not supported");
+}
+
 }  // namespace dragnn
 }  // namespace syntaxnet
research/syntaxnet/dragnn/components/syntaxnet/syntaxnet_component.cc

@@ -28,6 +28,7 @@
 #include "syntaxnet/sparse.pb.h"
 #include "syntaxnet/task_spec.pb.h"
 #include "syntaxnet/utils.h"
+#include "tensorflow/core/lib/strings/str_util.h"
 #include "tensorflow/core/platform/logging.h"

 namespace syntaxnet {

@@ -105,7 +106,7 @@ void SyntaxNetComponent::InitializeComponent(const ComponentSpec &spec) {
     dims.push_back(StrCat(channel.embedding_dim()));
   }

+  context.SetParameter("neurosis_feature_syntax_version", "2");
   context.SetParameter("brain_parser_embedding_dims", utils::Join(dims, ";"));
   context.SetParameter("brain_parser_predicate_maps",
                        utils::Join(predicate_maps, ";"));

@@ -187,8 +188,9 @@ std::unique_ptr<Beam<SyntaxNetTransitionState>> SyntaxNetComponent::CreateBeam(
     return this->IsFinal(state);
   };
   auto oracle_function = [this](SyntaxNetTransitionState *state) {
-    VLOG(2) << "oracle_function action:" << this->GetOracleLabel(state);
-    return this->GetOracleLabel(state);
+    VLOG(2) << "oracle_function action:"
+            << tensorflow::str_util::Join(this->GetOracleVector(state), ", ");
+    return this->GetOracleVector(state);
   };
   auto beam_ptr = beam.get();
   auto advance_function = [this, beam_ptr](SyntaxNetTransitionState *state,

@@ -335,25 +337,32 @@ std::function<int(int, int, int)> SyntaxNetComponent::GetStepLookupFunction(
   }
 }

-void SyntaxNetComponent::AdvanceFromPrediction(const float transition_matrix[],
-                                               int transition_matrix_length) {
-  VLOG(2) << "Advancing from prediction.";
-  int matrix_index = 0;
-  int num_labels = transition_system_->NumActions(label_map_->Size());
-  for (int i = 0; i < batch_.size(); ++i) {
-    int max_beam_size = batch_.at(i)->max_size();
-    int matrix_size = num_labels * max_beam_size;
-    CHECK_LE(matrix_index + matrix_size, transition_matrix_length);
-    if (!batch_.at(i)->IsTerminal()) {
-      batch_.at(i)->AdvanceFromPrediction(&transition_matrix[matrix_index],
-                                          matrix_size, num_labels);
-    }
-    matrix_index += num_labels * max_beam_size;
-  }
-}
+bool SyntaxNetComponent::AdvanceFromPrediction(const float *transition_matrix,
+                                               int num_items, int num_actions) {
+  VLOG(2) << "Advancing from prediction, component = " << spec_.name();
+  const int num_static_actions =
+      transition_system_->NumActions(label_map_->Size());
+  if (num_static_actions != ParserTransitionSystem::kDynamicNumActions) {
+    CHECK_EQ(num_static_actions, num_actions)
+        << "[" << spec_.name()
+        << "] static action set does not match transition matrix";
+  }
+  for (int i = 0; i < batch_.size(); ++i) {
+    const int size = num_actions * batch_[i]->max_size();
+    if (!batch_[i]->IsTerminal()) {
+      bool success = batch_[i]->AdvanceFromPrediction(transition_matrix, size,
+                                                      num_actions);
+      if (!success) return false;
+    }
+    transition_matrix += size;
+  }
+  return true;
+}

 void SyntaxNetComponent::AdvanceFromOracle() {
-  VLOG(2) << "Advancing from oracle.";
+  VLOG(2) << "Advancing from oracle, component = " << spec_.name();
   for (auto &beam : batch_) {
     beam->AdvanceFromOracle();
   }

@@ -404,8 +413,18 @@ int SyntaxNetComponent::GetFixedFeatures(
     features.emplace_back(f);

     if (do_tracing_) {
       FixedFeatures fixed_features;
-      for (const string &name : f.description()) {
-        fixed_features.add_value_name(name);
-      }
+      CHECK_EQ(f.description_size(), f.id_size());
+      CHECK(f.weight_size() == 0 || f.weight_size() == f.id_size());
+      const bool has_weights = f.weight_size() != 0;
+      for (int i = 0; i < f.description_size(); ++i) {
+        if (has_weights) {
+          fixed_features.add_value_name(StrCat("id: ", f.id(i), " name: ",
+                                               f.description(i), " weight: ",
+                                               f.weight(i)));
+        } else {
+          fixed_features.add_value_name(
+              StrCat("id: ", f.id(i), " name: ", f.description(i)));
+        }
+      }
       fixed_features.set_feature_name("");
       auto *trace = GetLastStepInTrace(state->mutable_trace());

@@ -522,8 +541,8 @@ int SyntaxNetComponent::BulkGetFixedFeatures(
   // This would be a good place to add threading.
   for (int channel_id = 0; channel_id < num_channels; ++channel_id) {
     int feature_count = feature_counts[channel_id];
-    LOG(INFO) << "Feature count is " << feature_count << " for channel "
+    VLOG(2) << "Feature count is " << feature_count << " for channel "
             << channel_id;
     int32 *indices_tensor =
         extractor.AllocateIndexMemory(channel_id, feature_count);
     int64 *ids_tensor = extractor.AllocateIdMemory(channel_id, feature_count);

@@ -603,7 +622,9 @@ std::vector<std::vector<int>> SyntaxNetComponent::GetOracleLabels() const {
     for (int beam_idx = 0; beam_idx < beam->size(); ++beam_idx) {
       // Get the raw link features from the linked feature extractor.
       auto state = beam->beam_state(beam_idx);
-      oracle_labels.back().push_back(GetOracleLabel(state));
+      // Arbitrarily choose the first vector element.
+      oracle_labels.back().push_back(GetOracleVector(state).front());
     }
   }

   return oracle_labels;

@@ -661,13 +682,17 @@ bool SyntaxNetComponent::IsFinal(SyntaxNetTransitionState *state) const {
   return transition_system_->IsFinalState(*(state->parser_state()));
 }

-int SyntaxNetComponent::GetOracleLabel(SyntaxNetTransitionState *state) const {
+std::vector<int> SyntaxNetComponent::GetOracleVector(
+    SyntaxNetTransitionState *state) const {
   if (IsFinal(state)) {
     // It is not permitted to request an oracle label from a sentence that is
     // in a final state.
-    return -1;
+    return {-1};
   } else {
-    return transition_system_->GetNextGoldAction(*(state->parser_state()));
+    // TODO(googleuser): This should use the 'ParserAction' typedef.
+    std::vector<int> golds;
+    transition_system_->GetAllNextGoldActions(*(state->parser_state()), &golds);
+    return golds;
   }
 }
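The rewritten SyntaxNetComponent::AdvanceFromPrediction above drops the explicit matrix_index bookkeeping and instead advances the input pointer by one batch item's block of scores per iteration. A hedged sketch of the assumed score layout (illustration only, not part of the commit):

  // Each of the `num_items` rows (one per beam slot per batch item) holds
  // `num_actions` scores, so the score of action `action` for row `item` is:
  inline int ScoreIndex(int item, int action, int num_actions) {
    return item * num_actions + action;
  }
  // Advancing the pointer by num_actions * max_beam_size therefore skips
  // exactly one batch item's block of rows, which is what the new loop does.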
research/syntaxnet/dragnn/components/syntaxnet/syntaxnet_component.h

@@ -13,8 +13,8 @@
 // limitations under the License.
 // =============================================================================

-#ifndef NLP_SAFT_OPENSOURCE_DRAGNN_COMPONENTS_SYNTAXNET_SYNTAXNET_COMPONENT_H_
-#define NLP_SAFT_OPENSOURCE_DRAGNN_COMPONENTS_SYNTAXNET_SYNTAXNET_COMPONENT_H_
+#ifndef DRAGNN_COMPONENTS_SYNTAXNET_SYNTAXNET_COMPONENT_H_
+#define DRAGNN_COMPONENTS_SYNTAXNET_SYNTAXNET_COMPONENT_H_

 #include <vector>

@@ -81,9 +81,10 @@ class SyntaxNetComponent : public Component {
   std::function<int(int, int, int)> GetStepLookupFunction(
       const string &method) override;

-  // Advances this component from the given transition matrix.
-  void AdvanceFromPrediction(const float transition_matrix[],
-                             int transition_matrix_length) override;
+  // Advances this component from the given transition matrix. Returns false
+  // if the component could not be advanced.
+  bool AdvanceFromPrediction(const float *transition_matrix, int num_items,
+                             int num_actions) override;

   // Advances this component from the state oracles.
   void AdvanceFromOracle() override;

@@ -105,6 +106,13 @@ class SyntaxNetComponent : public Component {
   // component via the oracle until it is terminal.
   int BulkGetFixedFeatures(const BulkFeatureExtractor &extractor) override;

+  void BulkEmbedFixedFeatures(
+      int batch_size_padding, int num_steps_padding, int output_array_size,
+      const vector<const float *> &per_channel_embeddings,
+      float *embedding_matrix) override {
+    LOG(FATAL) << "Method not supported";
+  }
+
   // Extracts and returns the vector of LinkFeatures for the specified
   // channel. Note: these are NOT translated.
   std::vector<LinkFeatures> GetRawLinkFeatures(int channel_id) const override;

@@ -145,13 +153,13 @@ class SyntaxNetComponent : public Component {
   bool IsFinal(SyntaxNetTransitionState *state) const;

   // Oracle function for this component.
-  int GetOracleLabel(SyntaxNetTransitionState *state) const;
+  std::vector<int> GetOracleVector(SyntaxNetTransitionState *state) const;

   // State advance function for this component.
   void Advance(SyntaxNetTransitionState *state, int action,
                Beam<SyntaxNetTransitionState> *beam);

-  // Creates a new state for the given nlp_saft::SentenceExample.
+  // Creates a new state for the given example.
   std::unique_ptr<SyntaxNetTransitionState> CreateState(
       SyntaxNetSentence *example);

@@ -195,4 +203,4 @@ class SyntaxNetComponent : public Component {
 }  // namespace dragnn
 }  // namespace syntaxnet

-#endif  // NLP_SAFT_OPENSOURCE_DRAGNN_COMPONENTS_SYNTAXNET_SYNTAXNET_COMPONENT_H_
+#endif  // DRAGNN_COMPONENTS_SYNTAXNET_SYNTAXNET_COMPONENT_H_
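With GetOracleLabel replaced by GetOracleVector in the header above, code inside the component that still needs a single gold action has to pick one element of the vector; the commit's GetOracleLabels arbitrarily keeps the first. A minimal sketch of that collapse, fed by the vector returned from GetOracleVector (illustration only, not code from the commit):

  // Sketch: collapsing the new multi-label oracle to a single action.
  int FirstGoldAction(const std::vector<int> &golds) {
    // -1 mirrors the sentinel the component returns for final states.
    return golds.empty() ? -1 : golds.front();
  }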
research/syntaxnet/dragnn/components/syntaxnet/syntaxnet_component_test.cc

@@ -15,6 +15,8 @@
 #include "dragnn/components/syntaxnet/syntaxnet_component.h"

+#include <limits>
+
 #include "dragnn/core/input_batch_cache.h"
 #include "dragnn/core/test/generic.h"
 #include "dragnn/core/test/mock_transition_state.h"

@@ -197,8 +199,8 @@ TEST_F(SyntaxNetComponentTest, AdvancesFromPredictionAndTerminates) {
   // Transition the expected number of times.
   for (int i = 0; i < kExpectedNumTransitions; ++i) {
     EXPECT_FALSE(test_parser->IsTerminal());
-    test_parser->AdvanceFromPrediction(transition_matrix,
-                                       kNumPossibleTransitions * kBeamSize);
+    EXPECT_TRUE(test_parser->AdvanceFromPrediction(transition_matrix, kBeamSize,
+                                                   kNumPossibleTransitions));
   }

   // At this point, the test parser should be terminal.

@@ -225,6 +227,29 @@ TEST_F(SyntaxNetComponentTest, AdvancesFromPredictionAndTerminates) {
   // TODO(googleuser): What should the finalized data look like?
 }

+TEST_F(SyntaxNetComponentTest, AdvancesFromPredictionFailsWithNanWeights) {
+  // Create an empty input batch and beam vector to initialize the parser.
+  Sentence sentence_0;
+  TextFormat::ParseFromString(kSentence0, &sentence_0);
+  string sentence_0_str;
+  sentence_0.SerializeToString(&sentence_0_str);
+  auto test_parser = CreateParser({}, {sentence_0_str});
+
+  // There are 93 possible transitions for any given state. Create a transition
+  // array with a score of 10.0 for each transition.
+  constexpr int kBeamSize = 2;
+  constexpr int kNumPossibleTransitions = 93;
+  float transition_matrix[kNumPossibleTransitions * kBeamSize];
+  for (int i = 0; i < kNumPossibleTransitions * kBeamSize; ++i) {
+    transition_matrix[i] = std::numeric_limits<double>::quiet_NaN();
+  }
+  EXPECT_FALSE(test_parser->IsTerminal());
+  EXPECT_FALSE(test_parser->AdvanceFromPrediction(transition_matrix, kBeamSize,
+                                                  kNumPossibleTransitions));
+}
+
 TEST_F(SyntaxNetComponentTest, RetainsPassedTransitionStateData) {
   // Create and initialize the state->
   MockTransitionState mock_state_one;

@@ -269,8 +294,8 @@ TEST_F(SyntaxNetComponentTest, RetainsPassedTransitionStateData) {
   // Transition the expected number of times
   for (int i = 0; i < kExpectedNumTransitions; ++i) {
     EXPECT_FALSE(test_parser->IsTerminal());
-    test_parser->AdvanceFromPrediction(transition_matrix,
-                                       kNumPossibleTransitions * kBeamSize);
+    EXPECT_TRUE(test_parser->AdvanceFromPrediction(transition_matrix, kBeamSize,
+                                                   kNumPossibleTransitions));
   }

   // At this point, the test parser should be terminal.

@@ -326,8 +351,8 @@ TEST_F(SyntaxNetComponentTest, AdvancesFromPredictionForMultiSentenceBatches) {
   // Transition the expected number of times.
   for (int i = 0; i < kExpectedNumTransitions; ++i) {
     EXPECT_FALSE(test_parser->IsTerminal());
-    test_parser->AdvanceFromPrediction(
-        transition_matrix, kNumPossibleTransitions * kBeamSize * kBatchSize);
+    EXPECT_TRUE(test_parser->AdvanceFromPrediction(
+        transition_matrix, kBeamSize * kBatchSize, kNumPossibleTransitions));
   }

   // At this point, the test parser should be terminal.

@@ -382,8 +407,8 @@ TEST_F(SyntaxNetComponentTest,
   constexpr int kExpectedNumTransitions = kNumTokensInLongSentence * 2;
   for (int i = 0; i < kExpectedNumTransitions; ++i) {
     EXPECT_FALSE(test_parser->IsTerminal());
-    test_parser->AdvanceFromPrediction(
-        transition_matrix, kNumPossibleTransitions * kBeamSize * kBatchSize);
+    EXPECT_TRUE(test_parser->AdvanceFromPrediction(
+        transition_matrix, kBeamSize * kBatchSize, kNumPossibleTransitions));
   }

   // At this point, the test parser should be terminal.

@@ -467,7 +492,7 @@ TEST_F(SyntaxNetComponentTest, ResetAllowsReductionInBatchSize) {
   for (int i = 0; i < kExpectedNumTransitions; ++i) {
     EXPECT_FALSE(parser_component->IsTerminal());
     parser_component->AdvanceFromPrediction(
-        transition_matrix, kNumPossibleTransitions * kBeamSize * kBatchSize);
+        transition_matrix, kBeamSize * kBatchSize, kNumPossibleTransitions);
   }

   // At this point, the test parser should be terminal.

@@ -553,7 +578,7 @@ TEST_F(SyntaxNetComponentTest, ResetAllowsIncreaseInBatchSize) {
   for (int i = 0; i < kExpectedNumTransitions; ++i) {
     EXPECT_FALSE(parser_component->IsTerminal());
     parser_component->AdvanceFromPrediction(
-        transition_matrix, kNumPossibleTransitions * kBeamSize * kBatchSize);
+        transition_matrix, kBeamSize * kBatchSize, kNumPossibleTransitions);
   }

   // At this point, the test parser should be terminal.

@@ -611,8 +636,8 @@ TEST_F(SyntaxNetComponentTest, ResetCausesBeamToReset) {
   // Transition the expected number of times.
   for (int i = 0; i < kExpectedNumTransitions; ++i) {
     EXPECT_FALSE(test_parser->IsTerminal());
-    test_parser->AdvanceFromPrediction(transition_matrix,
-                                       kNumPossibleTransitions * kBeamSize);
+    EXPECT_TRUE(test_parser->AdvanceFromPrediction(transition_matrix, kBeamSize,
+                                                   kNumPossibleTransitions));
   }

   // At this point, the test parser should be terminal.

@@ -823,10 +848,10 @@ TEST_F(SyntaxNetComponentTest, ExportsFixedFeatures) {
   // Advance twice, so that the underlying parser fills the beam.
-  test_parser->AdvanceFromPrediction(
-      transition_matrix, kNumPossibleTransitions * kBeamSize * kBatchSize);
-  test_parser->AdvanceFromPrediction(
-      transition_matrix, kNumPossibleTransitions * kBeamSize * kBatchSize);
+  EXPECT_TRUE(test_parser->AdvanceFromPrediction(
+      transition_matrix, kBeamSize * kBatchSize, kNumPossibleTransitions));
+  EXPECT_TRUE(test_parser->AdvanceFromPrediction(
+      transition_matrix, kBeamSize * kBatchSize, kNumPossibleTransitions));

   // Get and check the raw link features.
   vector<int32> indices;

@@ -907,10 +932,10 @@ TEST_F(SyntaxNetComponentTest, AdvancesAccordingToHighestWeightedInputOption) {
   transition_matrix[kBatchOffset + 5] = 2 * kTransitionValue;

   // Advance twice, so that the underlying parser fills the beam.
-  test_parser->AdvanceFromPrediction(
-      transition_matrix, kNumPossibleTransitions * kBeamSize * kBatchSize);
-  test_parser->AdvanceFromPrediction(
-      transition_matrix, kNumPossibleTransitions * kBeamSize * kBatchSize);
+  EXPECT_TRUE(test_parser->AdvanceFromPrediction(
+      transition_matrix, kBeamSize * kBatchSize, kNumPossibleTransitions));
+  EXPECT_TRUE(test_parser->AdvanceFromPrediction(
+      transition_matrix, kBeamSize * kBatchSize, kNumPossibleTransitions));

   // Get and check the raw link features.
   vector<int32> indices;

@@ -1112,10 +1137,10 @@ TEST_F(SyntaxNetComponentTest, ExportsRawLinkFeatures) {
   // Advance twice, so that the underlying parser fills the beam.
-  test_parser->AdvanceFromPrediction(
-      transition_matrix, kNumPossibleTransitions * kBeamSize * kBatchSize);
-  test_parser->AdvanceFromPrediction(
-      transition_matrix, kNumPossibleTransitions * kBeamSize * kBatchSize);
+  EXPECT_TRUE(test_parser->AdvanceFromPrediction(
+      transition_matrix, kBeamSize * kBatchSize, kNumPossibleTransitions));
+  EXPECT_TRUE(test_parser->AdvanceFromPrediction(
+      transition_matrix, kBeamSize * kBatchSize, kNumPossibleTransitions));

   // Get and check the raw link features.
   constexpr int kNumLinkFeatures = 2;

@@ -1269,5 +1294,21 @@ TEST_F(SyntaxNetComponentTest, TracingOutputsFeatureNames) {
   EXPECT_EQ(link_features.at(1).feature_name(), "stack(1).focus");
 }

+TEST_F(SyntaxNetComponentTest, BulkEmbedFixedFeaturesIsNotSupported) {
+  // Create an empty input batch and beam vector to initialize the parser.
+  Sentence sentence_0;
+  // TODO(googleuser): Wrap this in a lint-friendly helper function.
+  TextFormat::ParseFromString(kSentence0, &sentence_0);
+  string sentence_0_str;
+  sentence_0.SerializeToString(&sentence_0_str);
+  constexpr int kBeamSize = 1;
+  auto test_parser = CreateParserWithBeamSize(kBeamSize, {}, {sentence_0_str});
+  EXPECT_TRUE(test_parser->IsReady());
+  EXPECT_DEATH(
+      test_parser->BulkEmbedFixedFeatures(0, 0, 0, {nullptr}, nullptr),
+      "Method not supported");
+}
+
 }  // namespace dragnn
 }  // namespace syntaxnet
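The updated tests above all follow the same sizing convention: the score array is rows times actions, and the two sizes are passed separately instead of one flattened length. A hedged sketch of that convention (the kBatchSize value here is an assumption for illustration, not taken from the tests):

  constexpr int kNumPossibleTransitions = 93;  // per the test comments
  constexpr int kBeamSize = 2;
  constexpr int kBatchSize = 2;  // assumed example value
  float transition_matrix[kNumPossibleTransitions * kBeamSize * kBatchSize] = {0.0f};
  // New call shape: AdvanceFromPrediction(matrix, rows, actions), where
  // rows = kBeamSize * kBatchSize, replacing the old single flattened length.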
research/syntaxnet/dragnn/components/syntaxnet/syntaxnet_link_feature_extractor.h

@@ -13,8 +13,8 @@
 // limitations under the License.
 // =============================================================================

-#ifndef NLP_SAFT_OPENSOURCE_DRAGNN_COMPONENTS_SYNTAXNET_SYNTAXNET_LINK_FEATURE_EXTRACTOR_H_
-#define NLP_SAFT_OPENSOURCE_DRAGNN_COMPONENTS_SYNTAXNET_SYNTAXNET_LINK_FEATURE_EXTRACTOR_H_
+#ifndef DRAGNN_COMPONENTS_SYNTAXNET_SYNTAXNET_LINK_FEATURE_EXTRACTOR_H_
+#define DRAGNN_COMPONENTS_SYNTAXNET_SYNTAXNET_LINK_FEATURE_EXTRACTOR_H_

 #include <string>
 #include <vector>

@@ -29,12 +29,8 @@ namespace syntaxnet {
 namespace dragnn {

 // Provides feature extraction for linked features in the
-// WrapperParserComponent. This re-ues the EmbeddingFeatureExtractor
-// architecture to get another set of feature extractors. Note that we should
-// ignore predicate maps here, and we don't care about the vocabulary size
-// because all the feature values will be used for translation, but this means
-// we can configure the extractor from the GCL using the standard
-// neurosis-lib.wf syntax.
+// WrapperParserComponent. This re-uses the EmbeddingFeatureExtractor
+// architecture to get another set of feature extractors.
 //
 // Because it uses a different prefix, it can be executed in the same wf.stage
 // as the regular fixed extractor.

@@ -67,4 +63,4 @@ class SyntaxNetLinkFeatureExtractor : public ParserEmbeddingFeatureExtractor {
 }  // namespace dragnn
 }  // namespace syntaxnet

-#endif  // NLP_SAFT_OPENSOURCE_DRAGNN_COMPONENTS_SYNTAXNET_SYNTAXNET_LINK_FEATURE_EXTRACTOR_H_
+#endif  // DRAGNN_COMPONENTS_SYNTAXNET_SYNTAXNET_LINK_FEATURE_EXTRACTOR_H_
research/syntaxnet/dragnn/components/syntaxnet/syntaxnet_link_feature_extractor_test.cc

@@ -34,7 +34,7 @@ class ExportSpecTest : public ::testing::Test {
 TEST_F(ExportSpecTest, WritesChannelSpec) {
   TaskContext context;
+  context.SetParameter("neurosis_feature_syntax_version", "2");
   context.SetParameter("link_features", "input.focus;stack.focus");
   context.SetParameter("link_embedding_names", "tagger;parser");
   context.SetParameter("link_predicate_maps", "none;none");
research/syntaxnet/dragnn/components/syntaxnet/syntaxnet_transition_state.cc

@@ -23,7 +23,9 @@ namespace dragnn {
 SyntaxNetTransitionState::SyntaxNetTransitionState(
     std::unique_ptr<ParserState> parser_state, SyntaxNetSentence *sentence)
-    : parser_state_(std::move(parser_state)), sentence_(sentence) {
+    : parser_state_(std::move(parser_state)),
+      sentence_(sentence),
+      is_gold_(false) {
   score_ = 0;
   current_beam_index_ = -1;
   parent_beam_index_ = 0;

@@ -60,21 +62,25 @@ std::unique_ptr<SyntaxNetTransitionState> SyntaxNetTransitionState::Clone()
   return new_state;
 }

-const int SyntaxNetTransitionState::ParentBeamIndex() const {
+int SyntaxNetTransitionState::ParentBeamIndex() const {
   return parent_beam_index_;
 }

-const int SyntaxNetTransitionState::GetBeamIndex() const {
+int SyntaxNetTransitionState::GetBeamIndex() const {
   return current_beam_index_;
 }

+bool SyntaxNetTransitionState::IsGold() const { return is_gold_; }
+
+void SyntaxNetTransitionState::SetGold(bool is_gold) { is_gold_ = is_gold; }
+
-void SyntaxNetTransitionState::SetBeamIndex(const int index) {
+void SyntaxNetTransitionState::SetBeamIndex(int index) {
   current_beam_index_ = index;
 }

-const float SyntaxNetTransitionState::GetScore() const { return score_; }
+float SyntaxNetTransitionState::GetScore() const { return score_; }

-void SyntaxNetTransitionState::SetScore(const float score) { score_ = score; }
+void SyntaxNetTransitionState::SetScore(float score) { score_ = score; }

 string SyntaxNetTransitionState::HTMLRepresentation() const {
   // Crude HTML string showing the stack and the word on the input.
research/syntaxnet/dragnn/components/syntaxnet/syntaxnet_transition_state.h
...
@@ -13,8 +13,8 @@
 // limitations under the License.
 // =============================================================================
 
-#ifndef NLP_SAFT_OPENSOURCE_DRAGNN_COMPONENTS_SYNTAXNET_SYNTAXNET_TRANSITION_STATE_H_
-#define NLP_SAFT_OPENSOURCE_DRAGNN_COMPONENTS_SYNTAXNET_SYNTAXNET_TRANSITION_STATE_H_
+#ifndef DRAGNN_COMPONENTS_SYNTAXNET_SYNTAXNET_TRANSITION_STATE_H_
+#define DRAGNN_COMPONENTS_SYNTAXNET_SYNTAXNET_TRANSITION_STATE_H_
 
 #include <vector>
...
@@ -31,11 +31,11 @@ namespace dragnn {
 class SyntaxNetTransitionState
     : public CloneableTransitionState<SyntaxNetTransitionState> {
  public:
-  // Create a SyntaxNetTransitionState to wrap this nlp_saft::ParserState.
+  // Creates a SyntaxNetTransitionState to wrap this ParserState.
   SyntaxNetTransitionState(std::unique_ptr<ParserState> parser_state,
                            SyntaxNetSentence *sentence);
 
-  // Initialize this TransitionState from a previous TransitionState. The
+  // Initializes this TransitionState from a previous TransitionState. The
   // ParentBeamIndex is the location of that previous TransitionState in the
   // provided beam.
   void Init(const TransitionState &parent) override;
...
@@ -43,21 +43,27 @@ class SyntaxNetTransitionState
   // Produces a new state with the same backing data as this state.
   std::unique_ptr<SyntaxNetTransitionState> Clone() const override;
 
-  // Return the beam index of the state passed into the initializer of this
+  // Returns the beam index of the state passed into the initializer of this
   // TransitionState.
-  const int ParentBeamIndex() const override;
+  int ParentBeamIndex() const override;
 
-  // Get the current beam index for this state.
-  const int GetBeamIndex() const override;
+  // Gets the current beam index for this state.
+  int GetBeamIndex() const override;
 
-  // Set the current beam index for this state.
-  void SetBeamIndex(const int index) override;
+  // Sets the current beam index for this state.
+  void SetBeamIndex(int index) override;
 
-  // Get the score associated with this transition state.
-  const float GetScore() const override;
+  // Gets the score associated with this transition state.
+  float GetScore() const override;
 
-  // Set the score associated with this transition state.
-  void SetScore(const float score) override;
+  // Sets the score associated with this transition state.
+  void SetScore(float score) override;
+
+  // Gets the state's gold-ness (if it is on or consistent with the oracle path).
+  bool IsGold() const override;
+
+  // Sets the gold-ness of this state.
+  void SetGold(bool is_gold) override;
 
   // Depicts this state as an HTML-language string.
   string HTMLRepresentation() const override;
...
@@ -108,7 +114,7 @@ class SyntaxNetTransitionState
     parent_for_token_.insert(parent_for_token_.begin() + token, parent);
   }
 
-  // Accessor for the underlying nlp_saft::ParserState.
+  // Accessor for the underlying ParserState.
   ParserState *parser_state() { return parser_state_.get(); }
 
   // Accessor for the underlying sentence object.
...
@@ -151,9 +157,12 @@ class SyntaxNetTransitionState
   // Trace of the history to produce this state.
   std::unique_ptr<ComponentTrace> trace_;
+
+  // True if this state is gold.
+  bool is_gold_;
 };
 
 }  // namespace dragnn
 }  // namespace syntaxnet
 
-#endif  // NLP_SAFT_OPENSOURCE_DRAGNN_COMPONENTS_SYNTAXNET_SYNTAXNET_TRANSITION_STATE_H_
+#endif  // DRAGNN_COMPONENTS_SYNTAXNET_SYNTAXNET_TRANSITION_STATE_H_
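The new gold-ness accessors let a beam-managing component tag which transition states lie on the oracle path during training. A minimal sketch of that usage, assuming a hypothetical helper (MarkGoldStates and the beam vector are illustrative names, not APIs introduced by this commit):

#include <memory>
#include <vector>

// Illustrative sketch only -- assumes the dragnn headers declaring
// SyntaxNetTransitionState are already included.
// Marks exactly one state in the beam as gold (the one that followed the
// oracle transition) and clears the flag on every other state.
void MarkGoldStates(
    std::vector<std::unique_ptr<SyntaxNetTransitionState>> *beam,
    int oracle_index) {
  for (int i = 0; i < static_cast<int>(beam->size()); ++i) {
    (*beam)[i]->SetGold(i == oracle_index);
  }
}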
research/syntaxnet/dragnn/components/syntaxnet/syntaxnet_transition_state_test.cc
...
@@ -134,6 +134,22 @@ TEST_F(SyntaxNetTransitionStateTest, CanSetAndGetScore) {
   EXPECT_EQ(test_state->GetScore(), kNewScore);
 }
 
+// Validates the consistency of the goldness setter and getter.
+TEST_F(SyntaxNetTransitionStateTest, CanSetAndGetGold) {
+  // Create and initialize a test state.
+  MockTransitionState mock_state;
+  auto test_state = CreateState();
+  test_state->Init(mock_state);
+
+  constexpr bool kOldGold = true;
+  test_state->SetGold(kOldGold);
+  EXPECT_EQ(test_state->IsGold(), kOldGold);
+
+  constexpr bool kNewGold = false;
+  test_state->SetGold(kNewGold);
+  EXPECT_EQ(test_state->IsGold(), kNewGold);
+}
+
 // This test ensures that the initializing state's current index is saved
 // as the parent beam index of the state being initialized.
 TEST_F(SyntaxNetTransitionStateTest, ReportsParentBeamIndex) {
...
research/syntaxnet/dragnn/components/util/bulk_feature_extractor.h
...
@@ -13,8 +13,8 @@
 // limitations under the License.
 // =============================================================================
 
-#ifndef NLP_SAFT_OPENSOURCE_DRAGNN_COMPONENTS_UTIL_BULK_FEATURE_EXTRACTOR_H_
-#define NLP_SAFT_OPENSOURCE_DRAGNN_COMPONENTS_UTIL_BULK_FEATURE_EXTRACTOR_H_
+#ifndef DRAGNN_COMPONENTS_UTIL_BULK_FEATURE_EXTRACTOR_H_
+#define DRAGNN_COMPONENTS_UTIL_BULK_FEATURE_EXTRACTOR_H_
 
 #include <functional>
 #include <utility>
...
@@ -107,4 +107,4 @@ class BulkFeatureExtractor {
 }  // namespace dragnn
 }  // namespace syntaxnet
 
-#endif  // NLP_SAFT_OPENSOURCE_DRAGNN_COMPONENTS_UTIL_BULK_FEATURE_EXTRACTOR_H_
+#endif  // DRAGNN_COMPONENTS_UTIL_BULK_FEATURE_EXTRACTOR_H_
research/syntaxnet/dragnn/config_builder/__init__.py
0 → 100644