Unverified Commit 80178fc6 authored by Mark Omernick's avatar Mark Omernick Committed by GitHub
Browse files

Merge pull request #4153 from terryykoo/master

Export @195097388.
parents a84e1ef9 edea2b67
FROM ubuntu:16.10 FROM ubuntu:16.04
ENV SYNTAXNETDIR=/opt/tensorflow PATH=$PATH:/root/bin ENV SYNTAXNETDIR=/opt/tensorflow PATH=$PATH:/root/bin
...@@ -57,10 +57,10 @@ RUN python -m pip install \ ...@@ -57,10 +57,10 @@ RUN python -m pip install \
&& rm -rf /root/.cache/pip /tmp/pip* && rm -rf /root/.cache/pip /tmp/pip*
# Installs Bazel. # Installs Bazel.
RUN wget --quiet https://github.com/bazelbuild/bazel/releases/download/0.8.1/bazel-0.8.1-installer-linux-x86_64.sh \ RUN wget --quiet https://github.com/bazelbuild/bazel/releases/download/0.11.1/bazel-0.11.1-installer-linux-x86_64.sh \
&& chmod +x bazel-0.8.1-installer-linux-x86_64.sh \ && chmod +x bazel-0.11.1-installer-linux-x86_64.sh \
&& ./bazel-0.8.1-installer-linux-x86_64.sh \ && ./bazel-0.11.1-installer-linux-x86_64.sh \
&& rm ./bazel-0.8.1-installer-linux-x86_64.sh && rm ./bazel-0.11.1-installer-linux-x86_64.sh
COPY WORKSPACE $SYNTAXNETDIR/syntaxnet/WORKSPACE COPY WORKSPACE $SYNTAXNETDIR/syntaxnet/WORKSPACE
COPY tools/bazel.rc $SYNTAXNETDIR/syntaxnet/tools/bazel.rc COPY tools/bazel.rc $SYNTAXNETDIR/syntaxnet/tools/bazel.rc
......
...@@ -60,10 +60,10 @@ The simplest way to get started with DRAGNN is by loading our Docker container. ...@@ -60,10 +60,10 @@ The simplest way to get started with DRAGNN is by loading our Docker container.
[Here](g3doc/CLOUD.md) is a tutorial for running the DRAGNN container on [Here](g3doc/CLOUD.md) is a tutorial for running the DRAGNN container on
[GCP](https://cloud.google.com) (just as applicable to your own computer). [GCP](https://cloud.google.com) (just as applicable to your own computer).
### Ubuntu 16.10+ binary installation ### Ubuntu 16.04+ binary installation
_This process takes ~5 minutes, but is only compatible with Linux using GNU libc _This process takes ~5 minutes, but is only compatible with Linux using GNU libc
3.4.22 and above (e.g. Ubuntu 16.10)._ 3.4.22 and above (e.g. Ubuntu 16.04)._
Binary wheel packages are provided for TensorFlow and SyntaxNet. If you do not Binary wheel packages are provided for TensorFlow and SyntaxNet. If you do not
need to write new binary TensorFlow ops, these should suffice. need to write new binary TensorFlow ops, these should suffice.
...@@ -92,9 +92,9 @@ source. You'll need to install: ...@@ -92,9 +92,9 @@ source. You'll need to install:
* python 2.7: * python 2.7:
* Python 3 support is not available yet * Python 3 support is not available yet
* bazel 0.5.4: * bazel 0.11.1:
* Follow the instructions [here](http://bazel.build/docs/install.html) * Follow the instructions [here](http://bazel.build/docs/install.html)
* Alternately, Download bazel 0.5.4 <.deb> from * Alternately, Download bazel 0.11.1 <.deb> from
[https://github.com/bazelbuild/bazel/releases](https://github.com/bazelbuild/bazel/releases) [https://github.com/bazelbuild/bazel/releases](https://github.com/bazelbuild/bazel/releases)
for your system configuration. for your system configuration.
* Install it using the command: sudo dpkg -i <.deb file> * Install it using the command: sudo dpkg -i <.deb file>
...@@ -105,14 +105,14 @@ source. You'll need to install: ...@@ -105,14 +105,14 @@ source. You'll need to install:
* protocol buffers, with a version supported by TensorFlow: * protocol buffers, with a version supported by TensorFlow:
* check your protobuf version with `pip freeze | grep protobuf` * check your protobuf version with `pip freeze | grep protobuf`
* upgrade to a supported version with `pip install -U protobuf==3.3.0` * upgrade to a supported version with `pip install -U protobuf==3.3.0`
* autograd, with a version supported by TensorFlow:
* `pip install -U autograd==1.1.13`
* mock, the testing package: * mock, the testing package:
* `pip install mock` * `pip install mock`
* asciitree, to draw parse trees on the console for the demo: * asciitree, to draw parse trees on the console for the demo:
* `pip install asciitree` * `pip install asciitree`
* numpy, package for scientific computing: * numpy, package for scientific computing:
* `pip install numpy` * `pip install numpy`
* autograd 1.1.13, for automatic differentiation (not yet compatible with autograd v1.2 rewrite):
* `pip install autograd==1.1.13`
* pygraphviz to visualize traces and parse trees: * pygraphviz to visualize traces and parse trees:
* `apt-get install -y graphviz libgraphviz-dev` * `apt-get install -y graphviz libgraphviz-dev`
* `pip install pygraphviz * `pip install pygraphviz
......
local_repository( local_repository(
name = "org_tensorflow", name = "org_tensorflow",
path = "tensorflow", path = "tensorflow",
) )
# We need to pull in @io_bazel_rules_closure for TensorFlow. Bazel design # We need to pull in @io_bazel_rules_closure for TensorFlow. Bazel design
...@@ -9,22 +9,33 @@ local_repository( ...@@ -9,22 +9,33 @@ local_repository(
# @io_bazel_rules_closure. # @io_bazel_rules_closure.
http_archive( http_archive(
name = "io_bazel_rules_closure", name = "io_bazel_rules_closure",
sha256 = "25f5399f18d8bf9ce435f85c6bbf671ec4820bc4396b3022cc5dc4bc66303609", sha256 = "6691c58a2cd30a86776dd9bb34898b041e37136f2dc7e24cadaeaf599c95c657",
strip_prefix = "rules_closure-0.4.2", strip_prefix = "rules_closure-08039ba8ca59f64248bb3b6ae016460fe9c9914f",
urls = [ urls = [
"http://bazel-mirror.storage.googleapis.com/github.com/bazelbuild/rules_closure/archive/0.4.2.tar.gz", # 2017-08-30 "http://bazel-mirror.storage.googleapis.com/github.com/bazelbuild/rules_closure/archive/08039ba8ca59f64248bb3b6ae016460fe9c9914f.tar.gz",
"https://github.com/bazelbuild/rules_closure/archive/0.4.2.tar.gz", "https://github.com/bazelbuild/rules_closure/archive/08039ba8ca59f64248bb3b6ae016460fe9c9914f.tar.gz",
], ],
) )
load("@org_tensorflow//tensorflow:workspace.bzl", "tf_workspace") load("@org_tensorflow//tensorflow:workspace.bzl", "tf_workspace")
tf_workspace(path_prefix="", tf_repo_name="org_tensorflow")
# Test that Bazel is up-to-date. tf_workspace(
load("@org_tensorflow//tensorflow:workspace.bzl", "check_version") path_prefix = "",
check_version("0.4.2") tf_repo_name = "org_tensorflow",
)
http_archive(
name = "sling",
sha256 = "f1ce597476cb024808ca0a371a01db9dda4e0c58fb34a4f9c4ea91796f437b10",
strip_prefix = "sling-e3ae9d94eb1d9ee037a851070d54ed2eefaa928a",
urls = [
"http://bazel-mirror.storage.googleapis.com/github.com/google/sling/archive/e3ae9d94eb1d9ee037a851070d54ed2eefaa928a.tar.gz",
"https://github.com/google/sling/archive/e3ae9d94eb1d9ee037a851070d54ed2eefaa928a.tar.gz",
],
)
# Used by SLING.
bind( bind(
name = "protobuf", name = "zlib",
actual = "@protobuf_archive//:protobuf", actual = "@zlib_archive//:zlib",
) )
...@@ -9,3 +9,4 @@ COPY dragnn $SYNTAXNETDIR/syntaxnet/dragnn ...@@ -9,3 +9,4 @@ COPY dragnn $SYNTAXNETDIR/syntaxnet/dragnn
COPY syntaxnet $SYNTAXNETDIR/syntaxnet/syntaxnet COPY syntaxnet $SYNTAXNETDIR/syntaxnet/syntaxnet
COPY third_party $SYNTAXNETDIR/syntaxnet/third_party COPY third_party $SYNTAXNETDIR/syntaxnet/third_party
COPY util/utf8 $SYNTAXNETDIR/syntaxnet/util/utf8 COPY util/utf8 $SYNTAXNETDIR/syntaxnet/util/utf8
COPY WORKSPACE $SYNTAXNETDIR/syntaxnet/WORKSPACE
FROM ubuntu:16.10 FROM ubuntu:16.04
ENV SYNTAXNETDIR=/opt/tensorflow PATH=$PATH:/root/bin ENV SYNTAXNETDIR=/opt/tensorflow PATH=$PATH:/root/bin
...@@ -57,10 +57,10 @@ RUN python -m pip install \ ...@@ -57,10 +57,10 @@ RUN python -m pip install \
&& rm -rf /root/.cache/pip /tmp/pip* && rm -rf /root/.cache/pip /tmp/pip*
# Installs Bazel. # Installs Bazel.
RUN wget --quiet https://github.com/bazelbuild/bazel/releases/download/0.5.3/bazel-0.5.3-installer-linux-x86_64.sh \ RUN wget --quiet https://github.com/bazelbuild/bazel/releases/download/0.11.1/bazel-0.11.1-installer-linux-x86_64.sh \
&& chmod +x bazel-0.5.3-installer-linux-x86_64.sh \ && chmod +x bazel-0.11.1-installer-linux-x86_64.sh \
&& JAVA_HOME=/usr/lib/jvm/java-8-openjdk-amd64/ ./bazel-0.5.3-installer-linux-x86_64.sh \ && JAVA_HOME=/usr/lib/jvm/java-8-openjdk-amd64/ ./bazel-0.11.1-installer-linux-x86_64.sh \
&& rm ./bazel-0.5.3-installer-linux-x86_64.sh && rm ./bazel-0.11.1-installer-linux-x86_64.sh
COPY WORKSPACE $SYNTAXNETDIR/syntaxnet/WORKSPACE COPY WORKSPACE $SYNTAXNETDIR/syntaxnet/WORKSPACE
COPY tools/bazel.rc $SYNTAXNETDIR/syntaxnet/tools/bazel.rc COPY tools/bazel.rc $SYNTAXNETDIR/syntaxnet/tools/bazel.rc
...@@ -69,12 +69,9 @@ COPY tools/bazel.rc $SYNTAXNETDIR/syntaxnet/tools/bazel.rc ...@@ -69,12 +69,9 @@ COPY tools/bazel.rc $SYNTAXNETDIR/syntaxnet/tools/bazel.rc
# source. This makes it more convenient to re-compile DRAGNN / SyntaxNet for # source. This makes it more convenient to re-compile DRAGNN / SyntaxNet for
# development (though not as convenient as the docker-devel scripts). # development (though not as convenient as the docker-devel scripts).
RUN cd $SYNTAXNETDIR/syntaxnet \ RUN cd $SYNTAXNETDIR/syntaxnet \
&& git clone --branch r1.3 --recurse-submodules https://github.com/tensorflow/tensorflow \ && git clone --branch r1.8 --recurse-submodules https://github.com/tensorflow/tensorflow \
&& cd tensorflow \ && cd tensorflow \
# This line removes a bad archive target which causes Tensorflow install && tensorflow/tools/ci_build/builds/configured CPU \
# to fail.
&& sed -i '\@https://github.com/google/protobuf/archive/0b059a3d8a8f8aa40dde7bea55edca4ec5dfea66.tar.gz@d' tensorflow/workspace.bzl \
&& tensorflow/tools/ci_build/builds/configured CPU \\
&& cd $SYNTAXNETDIR/syntaxnet \ && cd $SYNTAXNETDIR/syntaxnet \
&& bazel build -c opt @org_tensorflow//tensorflow:tensorflow_py && bazel build -c opt @org_tensorflow//tensorflow:tensorflow_py
......
...@@ -3,7 +3,7 @@ ...@@ -3,7 +3,7 @@
# #
# It might be more efficient to use a minimal distribution, like Alpine. But # It might be more efficient to use a minimal distribution, like Alpine. But
# the upside of this being popular is that people might already have it. # the upside of this being popular is that people might already have it.
FROM ubuntu:16.10 FROM ubuntu:16.04
ENV SYNTAXNETDIR=/opt/tensorflow PATH=$PATH:/root/bin ENV SYNTAXNETDIR=/opt/tensorflow PATH=$PATH:/root/bin
......
...@@ -10,7 +10,8 @@ cc_library( ...@@ -10,7 +10,8 @@ cc_library(
"//dragnn/core:component_registry", "//dragnn/core:component_registry",
"//dragnn/core/interfaces:component", "//dragnn/core/interfaces:component",
"//dragnn/core/interfaces:transition_state", "//dragnn/core/interfaces:transition_state",
"//dragnn/protos:data_proto", "//dragnn/core/util:label",
"//dragnn/protos:data_proto_cc",
"//syntaxnet:base", "//syntaxnet:base",
], ],
alwayslink = 1, alwayslink = 1,
...@@ -27,7 +28,7 @@ cc_test( ...@@ -27,7 +28,7 @@ cc_test(
"//dragnn/core/test:mock_transition_state", "//dragnn/core/test:mock_transition_state",
"//dragnn/io:sentence_input_batch", "//dragnn/io:sentence_input_batch",
"//syntaxnet:base", "//syntaxnet:base",
"//syntaxnet:sentence_proto", "//syntaxnet:sentence_proto_cc",
"//syntaxnet:test_main", "//syntaxnet:test_main",
], ],
) )
...@@ -16,6 +16,7 @@ ...@@ -16,6 +16,7 @@
#include "dragnn/core/component_registry.h" #include "dragnn/core/component_registry.h"
#include "dragnn/core/interfaces/component.h" #include "dragnn/core/interfaces/component.h"
#include "dragnn/core/interfaces/transition_state.h" #include "dragnn/core/interfaces/transition_state.h"
#include "dragnn/core/util/label.h"
#include "dragnn/protos/data.pb.h" #include "dragnn/protos/data.pb.h"
#include "syntaxnet/base.h" #include "syntaxnet/base.h"
...@@ -90,7 +91,8 @@ class StatelessComponent : public Component { ...@@ -90,7 +91,8 @@ class StatelessComponent : public Component {
void AdvanceFromOracle() override { void AdvanceFromOracle() override {
LOG(FATAL) << "[" << name_ << "] AdvanceFromOracle not supported"; LOG(FATAL) << "[" << name_ << "] AdvanceFromOracle not supported";
} }
std::vector<std::vector<int>> GetOracleLabels() const override { std::vector<std::vector<std::vector<Label>>> GetOracleLabels()
const override {
LOG(FATAL) << "[" << name_ << "] Method not supported"; LOG(FATAL) << "[" << name_ << "] Method not supported";
} }
int GetFixedFeatures(std::function<int32 *(int)> allocate_indices, int GetFixedFeatures(std::function<int32 *(int)> allocate_indices,
...@@ -108,7 +110,15 @@ class StatelessComponent : public Component { ...@@ -108,7 +110,15 @@ class StatelessComponent : public Component {
float *embedding_output) override { float *embedding_output) override {
LOG(FATAL) << "[" << name_ << "] Method not supported"; LOG(FATAL) << "[" << name_ << "] Method not supported";
} }
void BulkEmbedDenseFixedFeatures(
const vector<const float *> &per_channel_embeddings,
float *embedding_output, int embedding_output_size,
int32 *offset_array_output, int offset_array_size) override {
LOG(FATAL) << "[" << name_ << "] Method not supported";
}
int BulkDenseFeatureSize() const override {
LOG(FATAL) << "Method not supported";
}
std::vector<LinkFeatures> GetRawLinkFeatures(int channel_id) const override { std::vector<LinkFeatures> GetRawLinkFeatures(int channel_id) const override {
LOG(FATAL) << "[" << name_ << "] Method not supported"; LOG(FATAL) << "[" << name_ << "] Method not supported";
} }
...@@ -118,9 +128,9 @@ class StatelessComponent : public Component { ...@@ -118,9 +128,9 @@ class StatelessComponent : public Component {
} }
private: private:
string name_; // component name string name_; // component name
int batch_size_ = 1; // number of sentences in current batch int batch_size_ = 1; // number of sentences in current batch
int beam_size_ = 1; // maximum beam size int beam_size_ = 1; // maximum beam size
// Parent states passed to InitializeData(), and passed along in GetBeam(). // Parent states passed to InitializeData(), and passed along in GetBeam().
std::vector<std::vector<const TransitionState *>> parent_states_; std::vector<std::vector<const TransitionState *>> parent_states_;
......
...@@ -16,18 +16,20 @@ cc_library( ...@@ -16,18 +16,20 @@ cc_library(
"//dragnn/core:input_batch_cache", "//dragnn/core:input_batch_cache",
"//dragnn/core/interfaces:component", "//dragnn/core/interfaces:component",
"//dragnn/core/interfaces:transition_state", "//dragnn/core/interfaces:transition_state",
"//dragnn/core/util:label",
"//dragnn/io:sentence_input_batch", "//dragnn/io:sentence_input_batch",
"//dragnn/io:syntaxnet_sentence", "//dragnn/io:syntaxnet_sentence",
"//dragnn/protos:data_proto", "//dragnn/protos:data_proto_cc",
"//dragnn/protos:spec_proto", "//dragnn/protos:spec_proto_cc",
"//dragnn/protos:trace_proto", "//dragnn/protos:trace_proto_cc",
"//syntaxnet:base", "//syntaxnet:base",
"//syntaxnet:parser_transitions", "//syntaxnet:parser_transitions",
"//syntaxnet:registry", "//syntaxnet:registry",
"//syntaxnet:sparse_proto", "//syntaxnet:sparse_proto_cc",
"//syntaxnet:task_context", "//syntaxnet:task_context",
"//syntaxnet:task_spec_proto", "//syntaxnet:task_spec_proto_cc",
"//syntaxnet:utils", "//syntaxnet:utils",
"//util/utf8:unicodetext",
], ],
alwayslink = 1, alwayslink = 1,
) )
...@@ -37,7 +39,7 @@ cc_library( ...@@ -37,7 +39,7 @@ cc_library(
srcs = ["syntaxnet_link_feature_extractor.cc"], srcs = ["syntaxnet_link_feature_extractor.cc"],
hdrs = ["syntaxnet_link_feature_extractor.h"], hdrs = ["syntaxnet_link_feature_extractor.h"],
deps = [ deps = [
"//dragnn/protos:spec_proto", "//dragnn/protos:spec_proto_cc",
"//syntaxnet:base", "//syntaxnet:base",
"//syntaxnet:embedding_feature_extractor", "//syntaxnet:embedding_feature_extractor",
"//syntaxnet:parser_transitions", "//syntaxnet:parser_transitions",
...@@ -53,7 +55,7 @@ cc_library( ...@@ -53,7 +55,7 @@ cc_library(
"//dragnn/core/interfaces:cloneable_transition_state", "//dragnn/core/interfaces:cloneable_transition_state",
"//dragnn/core/interfaces:transition_state", "//dragnn/core/interfaces:transition_state",
"//dragnn/io:syntaxnet_sentence", "//dragnn/io:syntaxnet_sentence",
"//dragnn/protos:trace_proto", "//dragnn/protos:trace_proto_cc",
"//syntaxnet:base", "//syntaxnet:base",
"//syntaxnet:parser_transitions", "//syntaxnet:parser_transitions",
], ],
...@@ -77,7 +79,7 @@ cc_test( ...@@ -77,7 +79,7 @@ cc_test(
"//dragnn/core/test:mock_transition_state", "//dragnn/core/test:mock_transition_state",
"//dragnn/io:sentence_input_batch", "//dragnn/io:sentence_input_batch",
"//syntaxnet:base", "//syntaxnet:base",
"//syntaxnet:sentence_proto", "//syntaxnet:sentence_proto_cc",
"//syntaxnet:test_main", "//syntaxnet:test_main",
], ],
) )
...@@ -88,7 +90,7 @@ cc_test( ...@@ -88,7 +90,7 @@ cc_test(
deps = [ deps = [
":syntaxnet_link_feature_extractor", ":syntaxnet_link_feature_extractor",
"//dragnn/core/test:generic", "//dragnn/core/test:generic",
"//dragnn/protos:spec_proto", "//dragnn/protos:spec_proto_cc",
"//syntaxnet:task_context", "//syntaxnet:task_context",
"//syntaxnet:test_main", "//syntaxnet:test_main",
], ],
...@@ -105,9 +107,9 @@ cc_test( ...@@ -105,9 +107,9 @@ cc_test(
"//dragnn/core/test:generic", "//dragnn/core/test:generic",
"//dragnn/core/test:mock_transition_state", "//dragnn/core/test:mock_transition_state",
"//dragnn/io:sentence_input_batch", "//dragnn/io:sentence_input_batch",
"//dragnn/protos:spec_proto", "//dragnn/protos:spec_proto_cc",
"//syntaxnet:base", "//syntaxnet:base",
"//syntaxnet:sentence_proto", "//syntaxnet:sentence_proto_cc",
"//syntaxnet:test_main", "//syntaxnet:test_main",
], ],
) )
...@@ -22,6 +22,7 @@ ...@@ -22,6 +22,7 @@
#include "dragnn/core/input_batch_cache.h" #include "dragnn/core/input_batch_cache.h"
#include "dragnn/core/interfaces/component.h" #include "dragnn/core/interfaces/component.h"
#include "dragnn/core/interfaces/transition_state.h" #include "dragnn/core/interfaces/transition_state.h"
#include "dragnn/core/util/label.h"
#include "dragnn/io/sentence_input_batch.h" #include "dragnn/io/sentence_input_batch.h"
#include "dragnn/io/syntaxnet_sentence.h" #include "dragnn/io/syntaxnet_sentence.h"
#include "syntaxnet/parser_state.h" #include "syntaxnet/parser_state.h"
...@@ -29,13 +30,12 @@ ...@@ -29,13 +30,12 @@
#include "syntaxnet/task_spec.pb.h" #include "syntaxnet/task_spec.pb.h"
#include "syntaxnet/utils.h" #include "syntaxnet/utils.h"
#include "tensorflow/core/lib/strings/str_util.h" #include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/logging.h"
#include "util/utf8/unicodetext.h"
namespace syntaxnet { namespace syntaxnet {
namespace dragnn { namespace dragnn {
using tensorflow::strings::StrCat;
namespace { namespace {
// Returns a new step in a trace based on a ComponentSpec. // Returns a new step in a trace based on a ComponentSpec.
...@@ -103,7 +103,7 @@ void SyntaxNetComponent::InitializeComponent(const ComponentSpec &spec) { ...@@ -103,7 +103,7 @@ void SyntaxNetComponent::InitializeComponent(const ComponentSpec &spec) {
names.push_back(channel.name()); names.push_back(channel.name());
fml.push_back(channel.fml()); fml.push_back(channel.fml());
predicate_maps.push_back(channel.predicate_map()); predicate_maps.push_back(channel.predicate_map());
dims.push_back(StrCat(channel.embedding_dim())); dims.push_back(tensorflow::strings::StrCat(channel.embedding_dim()));
} }
...@@ -125,7 +125,7 @@ void SyntaxNetComponent::InitializeComponent(const ComponentSpec &spec) { ...@@ -125,7 +125,7 @@ void SyntaxNetComponent::InitializeComponent(const ComponentSpec &spec) {
for (const LinkedFeatureChannel &channel : spec.linked_feature()) { for (const LinkedFeatureChannel &channel : spec.linked_feature()) {
names.push_back(channel.name()); names.push_back(channel.name());
fml.push_back(channel.fml()); fml.push_back(channel.fml());
dims.push_back(StrCat(channel.embedding_dim())); dims.push_back(tensorflow::strings::StrCat(channel.embedding_dim()));
source_components.push_back(channel.source_component()); source_components.push_back(channel.source_component());
source_layers.push_back(channel.source_layer()); source_layers.push_back(channel.source_layer());
source_translators.push_back(channel.source_translator()); source_translators.push_back(channel.source_translator());
...@@ -332,6 +332,22 @@ std::function<int(int, int, int)> SyntaxNetComponent::GetStepLookupFunction( ...@@ -332,6 +332,22 @@ std::function<int(int, int, int)> SyntaxNetComponent::GetStepLookupFunction(
return -1; return -1;
} }
}; };
} else if (method == "reverse-char") {
// Reverses the character-level index.
return [this](int batch_index, int beam_index, int value) {
SyntaxNetTransitionState *state =
batch_.at(batch_index)->beam_state(beam_index);
const auto *sentence = state->sentence()->sentence();
const string &text = sentence->text();
const int start_byte = sentence->token(0).start();
const int end_byte = sentence->token(sentence->token_size() - 1).end();
UnicodeText unicode;
unicode.PointToUTF8(text.data() + start_byte, end_byte - start_byte + 1);
const int num_chars = distance(unicode.begin(), unicode.end());
const int result = num_chars - value - 1;
if (result >= 0 && result < num_chars) return result;
return -1;
};
} else { } else {
LOG(FATAL) << "Unable to find step lookup function " << method; LOG(FATAL) << "Unable to find step lookup function " << method;
} }
...@@ -418,12 +434,12 @@ int SyntaxNetComponent::GetFixedFeatures( ...@@ -418,12 +434,12 @@ int SyntaxNetComponent::GetFixedFeatures(
const bool has_weights = f.weight_size() != 0; const bool has_weights = f.weight_size() != 0;
for (int i = 0; i < f.description_size(); ++i) { for (int i = 0; i < f.description_size(); ++i) {
if (has_weights) { if (has_weights) {
fixed_features.add_value_name(StrCat("id: ", f.id(i), fixed_features.add_value_name(tensorflow::strings::StrCat(
" name: ", f.description(i), "id: ", f.id(i), " name: ", f.description(i),
" weight: ", f.weight(i))); " weight: ", f.weight(i)));
} else { } else {
fixed_features.add_value_name( fixed_features.add_value_name(tensorflow::strings::StrCat(
StrCat("id: ", f.id(i), " name: ", f.description(i))); "id: ", f.id(i), " name: ", f.description(i)));
} }
} }
fixed_features.set_feature_name(""); fixed_features.set_feature_name("");
...@@ -615,16 +631,19 @@ std::vector<LinkFeatures> SyntaxNetComponent::GetRawLinkFeatures( ...@@ -615,16 +631,19 @@ std::vector<LinkFeatures> SyntaxNetComponent::GetRawLinkFeatures(
return features; return features;
} }
std::vector<std::vector<int>> SyntaxNetComponent::GetOracleLabels() const { std::vector<std::vector<std::vector<Label>>>
std::vector<std::vector<int>> oracle_labels; SyntaxNetComponent::GetOracleLabels() const {
for (const auto &beam : batch_) { std::vector<std::vector<std::vector<Label>>> oracle_labels(batch_.size());
oracle_labels.emplace_back(); for (int batch_idx = 0; batch_idx < batch_.size(); ++batch_idx) {
const auto &beam = batch_[batch_idx];
std::vector<std::vector<Label>> &output_beam = oracle_labels[batch_idx];
for (int beam_idx = 0; beam_idx < beam->size(); ++beam_idx) { for (int beam_idx = 0; beam_idx < beam->size(); ++beam_idx) {
// Get the raw link features from the linked feature extractor. // Get the raw link features from the linked feature extractor.
auto state = beam->beam_state(beam_idx); auto state = beam->beam_state(beam_idx);
// Arbitrarily choose the first vector element. // Arbitrarily choose the first vector element.
oracle_labels.back().push_back(GetOracleVector(state).front()); output_beam.emplace_back();
output_beam.back().emplace_back(GetOracleVector(state).front());
} }
} }
return oracle_labels; return oracle_labels;
......
...@@ -25,6 +25,7 @@ ...@@ -25,6 +25,7 @@
#include "dragnn/core/input_batch_cache.h" #include "dragnn/core/input_batch_cache.h"
#include "dragnn/core/interfaces/component.h" #include "dragnn/core/interfaces/component.h"
#include "dragnn/core/interfaces/transition_state.h" #include "dragnn/core/interfaces/transition_state.h"
#include "dragnn/core/util/label.h"
#include "dragnn/protos/data.pb.h" #include "dragnn/protos/data.pb.h"
#include "dragnn/protos/spec.pb.h" #include "dragnn/protos/spec.pb.h"
#include "dragnn/protos/trace.pb.h" #include "dragnn/protos/trace.pb.h"
...@@ -113,13 +114,24 @@ class SyntaxNetComponent : public Component { ...@@ -113,13 +114,24 @@ class SyntaxNetComponent : public Component {
LOG(FATAL) << "Method not supported"; LOG(FATAL) << "Method not supported";
} }
void BulkEmbedDenseFixedFeatures(
const vector<const float *> &per_channel_embeddings,
float *embedding_output, int embedding_output_size,
int32 *offset_array_output, int offset_array_size) override {
LOG(FATAL) << "Method not supported";
}
int BulkDenseFeatureSize() const override {
LOG(FATAL) << "Method not supported";
}
// Extracts and returns the vector of LinkFeatures for the specified // Extracts and returns the vector of LinkFeatures for the specified
// channel. Note: these are NOT translated. // channel. Note: these are NOT translated.
std::vector<LinkFeatures> GetRawLinkFeatures(int channel_id) const override; std::vector<LinkFeatures> GetRawLinkFeatures(int channel_id) const override;
// Returns a vector of oracle labels for each element in the beam and // Returns a vector of oracle labels for each element in the beam and
// batch. // batch.
std::vector<std::vector<int>> GetOracleLabels() const override; std::vector<std::vector<std::vector<Label>>> GetOracleLabels() const override;
// Annotate the underlying data object with the results of this Component's // Annotate the underlying data object with the results of this Component's
// calculation. // calculation.
......
...@@ -40,6 +40,7 @@ namespace dragnn { ...@@ -40,6 +40,7 @@ namespace dragnn {
namespace { namespace {
const char kSentence0[] = R"( const char kSentence0[] = R"(
text: "Sentence 0."
token { token {
word: "Sentence" start: 0 end: 7 tag: "NN" category: "NOUN" label: "ROOT" word: "Sentence" start: 0 end: 7 tag: "NN" category: "NOUN" label: "ROOT"
break_level: NO_BREAK break_level: NO_BREAK
...@@ -55,6 +56,7 @@ token { ...@@ -55,6 +56,7 @@ token {
)"; )";
const char kSentence1[] = R"( const char kSentence1[] = R"(
text: "Sentence 1."
token { token {
word: "Sentence" start: 0 end: 7 tag: "NN" category: "NOUN" label: "ROOT" word: "Sentence" start: 0 end: 7 tag: "NN" category: "NOUN" label: "ROOT"
break_level: NO_BREAK break_level: NO_BREAK
...@@ -70,6 +72,7 @@ token { ...@@ -70,6 +72,7 @@ token {
)"; )";
const char kLongSentence[] = R"( const char kLongSentence[] = R"(
text: "Sentence 123."
token { token {
word: "Sentence" start: 0 end: 7 tag: "NN" category: "NOUN" label: "ROOT" word: "Sentence" start: 0 end: 7 tag: "NN" category: "NOUN" label: "ROOT"
break_level: NO_BREAK break_level: NO_BREAK
...@@ -1310,5 +1313,30 @@ TEST_F(SyntaxNetComponentTest, BulkEmbedFixedFeaturesIsNotSupported) { ...@@ -1310,5 +1313,30 @@ TEST_F(SyntaxNetComponentTest, BulkEmbedFixedFeaturesIsNotSupported) {
"Method not supported"); "Method not supported");
} }
TEST_F(SyntaxNetComponentTest, GetStepLookupFunction) {
Sentence sentence_0;
TextFormat::ParseFromString(kSentence0, &sentence_0);
string sentence_0_str;
sentence_0.SerializeToString(&sentence_0_str);
constexpr int kBeamSize = 1;
auto test_parser = CreateParserWithBeamSize(kBeamSize, {}, {sentence_0_str});
ASSERT_TRUE(test_parser->IsReady());
const auto reverse_token_lookup =
test_parser->GetStepLookupFunction("reverse-token");
const int kNumTokens = sentence_0.token_size();
for (int i = 0; i < kNumTokens; ++i) {
EXPECT_EQ(i, reverse_token_lookup(0, 0, kNumTokens - i - 1));
}
const auto reverse_char_lookup =
test_parser->GetStepLookupFunction("reverse-char");
const int kNumChars = sentence_0.text().size(); // assumes ASCII
for (int i = 0; i < kNumChars; ++i) {
EXPECT_EQ(i, reverse_char_lookup(0, 0, kNumChars - i - 1));
}
}
} // namespace dragnn } // namespace dragnn
} // namespace syntaxnet } // namespace syntaxnet
...@@ -2,8 +2,9 @@ py_binary( ...@@ -2,8 +2,9 @@ py_binary(
name = "make_parser_spec", name = "make_parser_spec",
srcs = ["make_parser_spec.py"], srcs = ["make_parser_spec.py"],
deps = [ deps = [
"//dragnn/protos:spec_py_pb2", "//dragnn/protos:spec_pb2_py",
"//dragnn/python:spec_builder", "//dragnn/python:spec_builder",
"@absl_py//absl/flags",
"@org_tensorflow//tensorflow:tensorflow_py", "@org_tensorflow//tensorflow:tensorflow_py",
], ],
) )
...@@ -14,6 +14,7 @@ ...@@ -14,6 +14,7 @@
# ============================================================================== # ==============================================================================
"""Construct the spec for the CONLL2017 Parser baseline.""" """Construct the spec for the CONLL2017 Parser baseline."""
from absl import flags
import tensorflow as tf import tensorflow as tf
from tensorflow.python.platform import gfile from tensorflow.python.platform import gfile
...@@ -21,7 +22,6 @@ from tensorflow.python.platform import gfile ...@@ -21,7 +22,6 @@ from tensorflow.python.platform import gfile
from dragnn.protos import spec_pb2 from dragnn.protos import spec_pb2
from dragnn.python import spec_builder from dragnn.python import spec_builder
flags = tf.app.flags
FLAGS = flags.FLAGS FLAGS = flags.FLAGS
flags.DEFINE_string('spec_file', 'parser_spec.textproto', flags.DEFINE_string('spec_file', 'parser_spec.textproto',
......
...@@ -37,8 +37,9 @@ cc_library( ...@@ -37,8 +37,9 @@ cc_library(
":input_batch_cache", ":input_batch_cache",
"//dragnn/components/util:bulk_feature_extractor", "//dragnn/components/util:bulk_feature_extractor",
"//dragnn/core/interfaces:component", "//dragnn/core/interfaces:component",
"//dragnn/protos:spec_proto", "//dragnn/core/util:label",
"//dragnn/protos:trace_proto", "//dragnn/protos:spec_proto_cc",
"//dragnn/protos:trace_proto_cc",
], ],
) )
...@@ -51,9 +52,10 @@ cc_library( ...@@ -51,9 +52,10 @@ cc_library(
":index_translator", ":index_translator",
":input_batch_cache", ":input_batch_cache",
"//dragnn/components/util:bulk_feature_extractor", "//dragnn/components/util:bulk_feature_extractor",
"//dragnn/protos:data_proto", "//dragnn/core/util:label",
"//dragnn/protos:spec_proto", "//dragnn/protos:data_proto_cc",
"//dragnn/protos:trace_proto", "//dragnn/protos:spec_proto_cc",
"//dragnn/protos:trace_proto_cc",
"//syntaxnet:base", "//syntaxnet:base",
"//syntaxnet:registry", "//syntaxnet:registry",
], ],
...@@ -67,7 +69,7 @@ cc_library( ...@@ -67,7 +69,7 @@ cc_library(
":component_registry", ":component_registry",
":compute_session", ":compute_session",
":compute_session_impl", ":compute_session_impl",
"//dragnn/protos:spec_proto", "//dragnn/protos:spec_proto_cc",
"//syntaxnet:base", "//syntaxnet:base",
], ],
) )
...@@ -125,10 +127,13 @@ cc_test( ...@@ -125,10 +127,13 @@ cc_test(
"//dragnn/components/util:bulk_feature_extractor", "//dragnn/components/util:bulk_feature_extractor",
"//dragnn/core/interfaces:component", "//dragnn/core/interfaces:component",
"//dragnn/core/interfaces:input_batch", "//dragnn/core/interfaces:input_batch",
"//dragnn/core/test:fake_component_base",
"//dragnn/core/test:generic", "//dragnn/core/test:generic",
"//dragnn/core/test:mock_component", "//dragnn/core/test:mock_component",
"//dragnn/core/test:mock_transition_state", "//dragnn/core/test:mock_transition_state",
"//dragnn/core/util:label",
"//syntaxnet:base", "//syntaxnet:base",
"@org_tensorflow//tensorflow/core:test",
], ],
) )
...@@ -182,14 +187,24 @@ cc_test( ...@@ -182,14 +187,24 @@ cc_test(
# Tensorflow op kernel BUILD rules. # Tensorflow op kernel BUILD rules.
load( load(
"//dragnn:tensorflow_ops.bzl", "@org_tensorflow//tensorflow:tensorflow.bzl",
"tf_gen_op_libs", "tf_gen_op_libs",
"tf_gen_op_wrapper_py", "tf_gen_op_wrapper_py",
"tf_kernel_library", "tf_kernel_library",
) )
cc_library(
name = "shape_helpers",
hdrs = ["ops/shape_helpers.h"],
deps = [
"//syntaxnet:shape_helpers",
"@org_tensorflow//tensorflow/core:framework_headers_lib",
],
)
tf_gen_op_libs( tf_gen_op_libs(
op_lib_names = ["dragnn_ops"], op_lib_names = ["dragnn_ops"],
deps = [":shape_helpers"],
) )
tf_gen_op_wrapper_py( tf_gen_op_wrapper_py(
...@@ -199,6 +214,7 @@ tf_gen_op_wrapper_py( ...@@ -199,6 +214,7 @@ tf_gen_op_wrapper_py(
tf_gen_op_libs( tf_gen_op_libs(
op_lib_names = ["dragnn_bulk_ops"], op_lib_names = ["dragnn_bulk_ops"],
deps = [":shape_helpers"],
) )
tf_gen_op_wrapper_py( tf_gen_op_wrapper_py(
...@@ -231,8 +247,10 @@ cc_library( ...@@ -231,8 +247,10 @@ cc_library(
":compute_session_op", ":compute_session_op",
":compute_session_pool", ":compute_session_pool",
":resource_container", ":resource_container",
"//dragnn/protos:data_proto", ":shape_helpers",
"//dragnn/protos:spec_proto", "//dragnn/core/util:label",
"//dragnn/protos:data_proto_cc",
"//dragnn/protos:spec_proto_cc",
"//syntaxnet:base", "//syntaxnet:base",
"@org_tensorflow//third_party/eigen3", "@org_tensorflow//third_party/eigen3",
], ],
...@@ -248,6 +266,8 @@ cc_library( ...@@ -248,6 +266,8 @@ cc_library(
deps = [ deps = [
":compute_session_op", ":compute_session_op",
":resource_container", ":resource_container",
":shape_helpers",
"//dragnn/core/util:label",
"//syntaxnet:base", "//syntaxnet:base",
"@org_tensorflow//third_party/eigen3", "@org_tensorflow//third_party/eigen3",
], ],
...@@ -269,8 +289,10 @@ tf_kernel_library( ...@@ -269,8 +289,10 @@ tf_kernel_library(
":compute_session_op", ":compute_session_op",
":compute_session_pool", ":compute_session_pool",
":resource_container", ":resource_container",
"//dragnn/protos:data_proto", ":shape_helpers",
"//dragnn/protos:spec_proto", "//dragnn/core/util:label",
"//dragnn/protos:data_proto_cc",
"//dragnn/protos:spec_proto_cc",
"//syntaxnet:base", "//syntaxnet:base",
"@org_tensorflow//third_party/eigen3", "@org_tensorflow//third_party/eigen3",
], ],
...@@ -289,8 +311,10 @@ tf_kernel_library( ...@@ -289,8 +311,10 @@ tf_kernel_library(
":compute_session_op", ":compute_session_op",
":compute_session_pool", ":compute_session_pool",
":resource_container", ":resource_container",
":shape_helpers",
"//dragnn/components/util:bulk_feature_extractor", "//dragnn/components/util:bulk_feature_extractor",
"//dragnn/protos:spec_proto", "//dragnn/core/util:label",
"//dragnn/protos:spec_proto_cc",
"//syntaxnet:base", "//syntaxnet:base",
"@org_tensorflow//tensorflow/core:protos_all_cc", "@org_tensorflow//tensorflow/core:protos_all_cc",
"@org_tensorflow//third_party/eigen3", "@org_tensorflow//third_party/eigen3",
...@@ -309,6 +333,7 @@ cc_test( ...@@ -309,6 +333,7 @@ cc_test(
":resource_container", ":resource_container",
"//dragnn/core/test:generic", "//dragnn/core/test:generic",
"//dragnn/core/test:mock_compute_session", "//dragnn/core/test:mock_compute_session",
"//dragnn/core/util:label",
"//syntaxnet:base", "//syntaxnet:base",
"//syntaxnet:test_main", "//syntaxnet:test_main",
"@org_tensorflow//tensorflow/core:protos_all_cc", "@org_tensorflow//tensorflow/core:protos_all_cc",
...@@ -327,6 +352,7 @@ cc_test( ...@@ -327,6 +352,7 @@ cc_test(
":resource_container", ":resource_container",
"//dragnn/components/util:bulk_feature_extractor", "//dragnn/components/util:bulk_feature_extractor",
"//dragnn/core/test:mock_compute_session", "//dragnn/core/test:mock_compute_session",
"//dragnn/core/util:label",
"//syntaxnet:base", "//syntaxnet:base",
"//syntaxnet:test_main", "//syntaxnet:test_main",
"@org_tensorflow//tensorflow/core/kernels:ops_testutil", "@org_tensorflow//tensorflow/core/kernels:ops_testutil",
......
...@@ -22,6 +22,7 @@ ...@@ -22,6 +22,7 @@
#include "dragnn/core/index_translator.h" #include "dragnn/core/index_translator.h"
#include "dragnn/core/input_batch_cache.h" #include "dragnn/core/input_batch_cache.h"
#include "dragnn/core/interfaces/component.h" #include "dragnn/core/interfaces/component.h"
#include "dragnn/core/util/label.h"
#include "dragnn/protos/spec.pb.h" #include "dragnn/protos/spec.pb.h"
#include "dragnn/protos/trace.pb.h" #include "dragnn/protos/trace.pb.h"
...@@ -102,7 +103,7 @@ class ComputeSession { ...@@ -102,7 +103,7 @@ class ComputeSession {
const string &component_name, int channel_id) = 0; const string &component_name, int channel_id) = 0;
// Get the oracle labels for the given component. // Get the oracle labels for the given component.
virtual std::vector<std::vector<int>> EmitOracleLabels( virtual std::vector<std::vector<std::vector<Label>>> EmitOracleLabels(
const string &component_name) = 0; const string &component_name) = 0;
// Returns true if the given component is terminal. // Returns true if the given component is terminal.
...@@ -126,6 +127,9 @@ class ComputeSession { ...@@ -126,6 +127,9 @@ class ComputeSession {
// bypassing de-serialization. // bypassing de-serialization.
virtual void SetInputBatchCache(std::unique_ptr<InputBatchCache> batch) = 0; virtual void SetInputBatchCache(std::unique_ptr<InputBatchCache> batch) = 0;
// Returns the current InputBatchCache, or null if there is none.
virtual InputBatchCache *GetInputBatchCache() = 0;
// Resets all components owned by this ComputeSession. // Resets all components owned by this ComputeSession.
virtual void ResetSession() = 0; virtual void ResetSession() = 0;
......
...@@ -18,6 +18,7 @@ ...@@ -18,6 +18,7 @@
#include <algorithm> #include <algorithm>
#include <utility> #include <utility>
#include "dragnn/core/util/label.h"
#include "dragnn/protos/data.pb.h" #include "dragnn/protos/data.pb.h"
#include "dragnn/protos/spec.pb.h" #include "dragnn/protos/spec.pb.h"
#include "dragnn/protos/trace.pb.h" #include "dragnn/protos/trace.pb.h"
...@@ -123,8 +124,12 @@ void ComputeSessionImpl::InitializeComponentData(const string &component_name, ...@@ -123,8 +124,12 @@ void ComputeSessionImpl::InitializeComponentData(const string &component_name,
VLOG(1) << "Source result found. Using prior initialization vector for " VLOG(1) << "Source result found. Using prior initialization vector for "
<< component_name; << component_name;
auto source = source_result->second; auto source = source_result->second;
CHECK(source->IsTerminal()) << "Source is not terminal for component '" CHECK(source->IsTerminal())
<< component_name << "'. Exiting."; << "Source component '" << source->Name()
<< "' for currently active component '" << component_name
<< "' is not terminal. "
<< "Are you using bulk feature extraction with only linked features? "
<< "If so, consider using the StatelessComponent instead. Exiting.";
component->InitializeData(source->GetBeam(), max_beam_size, component->InitializeData(source->GetBeam(), max_beam_size,
input_data_.get()); input_data_.get());
} }
...@@ -219,8 +224,8 @@ std::vector<LinkFeatures> ComputeSessionImpl::GetTranslatedLinkFeatures( ...@@ -219,8 +224,8 @@ std::vector<LinkFeatures> ComputeSessionImpl::GetTranslatedLinkFeatures(
return features; return features;
} }
std::vector<std::vector<int>> ComputeSessionImpl::EmitOracleLabels( std::vector<std::vector<std::vector<Label>>>
const string &component_name) { ComputeSessionImpl::EmitOracleLabels(const string &component_name) {
return GetReadiedComponent(component_name)->GetOracleLabels(); return GetReadiedComponent(component_name)->GetOracleLabels();
} }
...@@ -303,6 +308,10 @@ void ComputeSessionImpl::SetInputBatchCache( ...@@ -303,6 +308,10 @@ void ComputeSessionImpl::SetInputBatchCache(
input_data_ = std::move(batch); input_data_ = std::move(batch);
} }
InputBatchCache *ComputeSessionImpl::GetInputBatchCache() {
return input_data_.get();
}
void ComputeSessionImpl::ResetSession() { void ComputeSessionImpl::ResetSession() {
// Reset all component states. // Reset all component states.
for (auto &component_pair : components_) { for (auto &component_pair : components_) {
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment