Unverified Commit 80178fc6 authored by Mark Omernick's avatar Mark Omernick Committed by GitHub
Browse files

Merge pull request #4153 from terryykoo/master

Export @195097388.
parents a84e1ef9 edea2b67
FROM ubuntu:16.10
FROM ubuntu:16.04
ENV SYNTAXNETDIR=/opt/tensorflow PATH=$PATH:/root/bin
......@@ -57,10 +57,10 @@ RUN python -m pip install \
&& rm -rf /root/.cache/pip /tmp/pip*
# Installs Bazel.
RUN wget --quiet https://github.com/bazelbuild/bazel/releases/download/0.8.1/bazel-0.8.1-installer-linux-x86_64.sh \
&& chmod +x bazel-0.8.1-installer-linux-x86_64.sh \
&& ./bazel-0.8.1-installer-linux-x86_64.sh \
&& rm ./bazel-0.8.1-installer-linux-x86_64.sh
RUN wget --quiet https://github.com/bazelbuild/bazel/releases/download/0.11.1/bazel-0.11.1-installer-linux-x86_64.sh \
&& chmod +x bazel-0.11.1-installer-linux-x86_64.sh \
&& ./bazel-0.11.1-installer-linux-x86_64.sh \
&& rm ./bazel-0.11.1-installer-linux-x86_64.sh
COPY WORKSPACE $SYNTAXNETDIR/syntaxnet/WORKSPACE
COPY tools/bazel.rc $SYNTAXNETDIR/syntaxnet/tools/bazel.rc
......
......@@ -60,10 +60,10 @@ The simplest way to get started with DRAGNN is by loading our Docker container.
[Here](g3doc/CLOUD.md) is a tutorial for running the DRAGNN container on
[GCP](https://cloud.google.com) (just as applicable to your own computer).
### Ubuntu 16.10+ binary installation
### Ubuntu 16.04+ binary installation
_This process takes ~5 minutes, but is only compatible with Linux using GNU libc
3.4.22 and above (e.g. Ubuntu 16.10)._
3.4.22 and above (e.g. Ubuntu 16.04)._
Binary wheel packages are provided for TensorFlow and SyntaxNet. If you do not
need to write new binary TensorFlow ops, these should suffice.
......@@ -92,9 +92,9 @@ source. You'll need to install:
* python 2.7:
* Python 3 support is not available yet
* bazel 0.5.4:
* bazel 0.11.1:
* Follow the instructions [here](http://bazel.build/docs/install.html)
* Alternately, Download bazel 0.5.4 <.deb> from
* Alternatively, download the bazel 0.11.1 <.deb> from
[https://github.com/bazelbuild/bazel/releases](https://github.com/bazelbuild/bazel/releases)
for your system configuration.
* Install it using the command: sudo dpkg -i <.deb file>
......@@ -105,14 +105,14 @@ source. You'll need to install:
* protocol buffers, with a version supported by TensorFlow:
* check your protobuf version with `pip freeze | grep protobuf`
* upgrade to a supported version with `pip install -U protobuf==3.3.0`
* autograd, with a version supported by TensorFlow:
* `pip install -U autograd==1.1.13`
* mock, the testing package:
* `pip install mock`
* asciitree, to draw parse trees on the console for the demo:
* `pip install asciitree`
* numpy, package for scientific computing:
* `pip install numpy`
* autograd 1.1.13, for automatic differentiation (not yet compatible with autograd v1.2 rewrite):
* `pip install autograd==1.1.13`
* pygraphviz to visualize traces and parse trees:
* `apt-get install -y graphviz libgraphviz-dev`
* `pip install pygraphviz`
......
......@@ -9,22 +9,33 @@ local_repository(
# @io_bazel_rules_closure.
http_archive(
name = "io_bazel_rules_closure",
sha256 = "25f5399f18d8bf9ce435f85c6bbf671ec4820bc4396b3022cc5dc4bc66303609",
strip_prefix = "rules_closure-0.4.2",
sha256 = "6691c58a2cd30a86776dd9bb34898b041e37136f2dc7e24cadaeaf599c95c657",
strip_prefix = "rules_closure-08039ba8ca59f64248bb3b6ae016460fe9c9914f",
urls = [
"http://bazel-mirror.storage.googleapis.com/github.com/bazelbuild/rules_closure/archive/0.4.2.tar.gz", # 2017-08-30
"https://github.com/bazelbuild/rules_closure/archive/0.4.2.tar.gz",
"http://bazel-mirror.storage.googleapis.com/github.com/bazelbuild/rules_closure/archive/08039ba8ca59f64248bb3b6ae016460fe9c9914f.tar.gz",
"https://github.com/bazelbuild/rules_closure/archive/08039ba8ca59f64248bb3b6ae016460fe9c9914f.tar.gz",
],
)
load("@org_tensorflow//tensorflow:workspace.bzl", "tf_workspace")
tf_workspace(path_prefix="", tf_repo_name="org_tensorflow")
# Test that Bazel is up-to-date.
load("@org_tensorflow//tensorflow:workspace.bzl", "check_version")
check_version("0.4.2")
tf_workspace(
path_prefix = "",
tf_repo_name = "org_tensorflow",
)
http_archive(
name = "sling",
sha256 = "f1ce597476cb024808ca0a371a01db9dda4e0c58fb34a4f9c4ea91796f437b10",
strip_prefix = "sling-e3ae9d94eb1d9ee037a851070d54ed2eefaa928a",
urls = [
"http://bazel-mirror.storage.googleapis.com/github.com/google/sling/archive/e3ae9d94eb1d9ee037a851070d54ed2eefaa928a.tar.gz",
"https://github.com/google/sling/archive/e3ae9d94eb1d9ee037a851070d54ed2eefaa928a.tar.gz",
],
)
# Used by SLING.
bind(
name = "protobuf",
actual = "@protobuf_archive//:protobuf",
name = "zlib",
actual = "@zlib_archive//:zlib",
)
......@@ -9,3 +9,4 @@ COPY dragnn $SYNTAXNETDIR/syntaxnet/dragnn
COPY syntaxnet $SYNTAXNETDIR/syntaxnet/syntaxnet
COPY third_party $SYNTAXNETDIR/syntaxnet/third_party
COPY util/utf8 $SYNTAXNETDIR/syntaxnet/util/utf8
COPY WORKSPACE $SYNTAXNETDIR/syntaxnet/WORKSPACE
FROM ubuntu:16.10
FROM ubuntu:16.04
ENV SYNTAXNETDIR=/opt/tensorflow PATH=$PATH:/root/bin
......@@ -57,10 +57,10 @@ RUN python -m pip install \
&& rm -rf /root/.cache/pip /tmp/pip*
# Installs Bazel.
RUN wget --quiet https://github.com/bazelbuild/bazel/releases/download/0.5.3/bazel-0.5.3-installer-linux-x86_64.sh \
&& chmod +x bazel-0.5.3-installer-linux-x86_64.sh \
&& JAVA_HOME=/usr/lib/jvm/java-8-openjdk-amd64/ ./bazel-0.5.3-installer-linux-x86_64.sh \
&& rm ./bazel-0.5.3-installer-linux-x86_64.sh
RUN wget --quiet https://github.com/bazelbuild/bazel/releases/download/0.11.1/bazel-0.11.1-installer-linux-x86_64.sh \
&& chmod +x bazel-0.11.1-installer-linux-x86_64.sh \
&& JAVA_HOME=/usr/lib/jvm/java-8-openjdk-amd64/ ./bazel-0.11.1-installer-linux-x86_64.sh \
&& rm ./bazel-0.11.1-installer-linux-x86_64.sh
COPY WORKSPACE $SYNTAXNETDIR/syntaxnet/WORKSPACE
COPY tools/bazel.rc $SYNTAXNETDIR/syntaxnet/tools/bazel.rc
......@@ -69,12 +69,9 @@ COPY tools/bazel.rc $SYNTAXNETDIR/syntaxnet/tools/bazel.rc
# source. This makes it more convenient to re-compile DRAGNN / SyntaxNet for
# development (though not as convenient as the docker-devel scripts).
RUN cd $SYNTAXNETDIR/syntaxnet \
&& git clone --branch r1.3 --recurse-submodules https://github.com/tensorflow/tensorflow \
&& git clone --branch r1.8 --recurse-submodules https://github.com/tensorflow/tensorflow \
&& cd tensorflow \
# This line removes a bad archive target which causes Tensorflow install
# to fail.
&& sed -i '\@https://github.com/google/protobuf/archive/0b059a3d8a8f8aa40dde7bea55edca4ec5dfea66.tar.gz@d' tensorflow/workspace.bzl \
&& tensorflow/tools/ci_build/builds/configured CPU \\
&& tensorflow/tools/ci_build/builds/configured CPU \
&& cd $SYNTAXNETDIR/syntaxnet \
&& bazel build -c opt @org_tensorflow//tensorflow:tensorflow_py
......
......@@ -3,7 +3,7 @@
#
# It might be more efficient to use a minimal distribution, like Alpine. But
# the upside of this being popular is that people might already have it.
FROM ubuntu:16.10
FROM ubuntu:16.04
ENV SYNTAXNETDIR=/opt/tensorflow PATH=$PATH:/root/bin
......
......@@ -10,7 +10,8 @@ cc_library(
"//dragnn/core:component_registry",
"//dragnn/core/interfaces:component",
"//dragnn/core/interfaces:transition_state",
"//dragnn/protos:data_proto",
"//dragnn/core/util:label",
"//dragnn/protos:data_proto_cc",
"//syntaxnet:base",
],
alwayslink = 1,
......@@ -27,7 +28,7 @@ cc_test(
"//dragnn/core/test:mock_transition_state",
"//dragnn/io:sentence_input_batch",
"//syntaxnet:base",
"//syntaxnet:sentence_proto",
"//syntaxnet:sentence_proto_cc",
"//syntaxnet:test_main",
],
)
......@@ -16,6 +16,7 @@
#include "dragnn/core/component_registry.h"
#include "dragnn/core/interfaces/component.h"
#include "dragnn/core/interfaces/transition_state.h"
#include "dragnn/core/util/label.h"
#include "dragnn/protos/data.pb.h"
#include "syntaxnet/base.h"
......@@ -90,7 +91,8 @@ class StatelessComponent : public Component {
void AdvanceFromOracle() override {
LOG(FATAL) << "[" << name_ << "] AdvanceFromOracle not supported";
}
std::vector<std::vector<int>> GetOracleLabels() const override {
// Oracle labels are undefined for this component: calling this always
// aborts the process via LOG(FATAL).
std::vector<std::vector<std::vector<Label>>> GetOracleLabels()
const override {
LOG(FATAL) << "[" << name_ << "] Method not supported";
}
int GetFixedFeatures(std::function<int32 *(int)> allocate_indices,
......@@ -108,7 +110,15 @@ class StatelessComponent : public Component {
float *embedding_output) override {
LOG(FATAL) << "[" << name_ << "] Method not supported";
}
// Bulk dense fixed-feature embedding is not implemented by this
// component; any call aborts via LOG(FATAL), tagged with the component
// name for easier debugging.
void BulkEmbedDenseFixedFeatures(
const vector<const float *> &per_channel_embeddings,
float *embedding_output, int embedding_output_size,
int32 *offset_array_output, int offset_array_size) override {
LOG(FATAL) << "[" << name_ << "] Method not supported";
}
// Dense feature sizing is not implemented by this component; any call
// aborts via LOG(FATAL).
int BulkDenseFeatureSize() const override {
LOG(FATAL) << "Method not supported";
}
// Raw link-feature extraction is not implemented by this component;
// any call aborts via LOG(FATAL).
std::vector<LinkFeatures> GetRawLinkFeatures(int channel_id) const override {
LOG(FATAL) << "[" << name_ << "] Method not supported";
}
......
......@@ -16,18 +16,20 @@ cc_library(
"//dragnn/core:input_batch_cache",
"//dragnn/core/interfaces:component",
"//dragnn/core/interfaces:transition_state",
"//dragnn/core/util:label",
"//dragnn/io:sentence_input_batch",
"//dragnn/io:syntaxnet_sentence",
"//dragnn/protos:data_proto",
"//dragnn/protos:spec_proto",
"//dragnn/protos:trace_proto",
"//dragnn/protos:data_proto_cc",
"//dragnn/protos:spec_proto_cc",
"//dragnn/protos:trace_proto_cc",
"//syntaxnet:base",
"//syntaxnet:parser_transitions",
"//syntaxnet:registry",
"//syntaxnet:sparse_proto",
"//syntaxnet:sparse_proto_cc",
"//syntaxnet:task_context",
"//syntaxnet:task_spec_proto",
"//syntaxnet:task_spec_proto_cc",
"//syntaxnet:utils",
"//util/utf8:unicodetext",
],
alwayslink = 1,
)
......@@ -37,7 +39,7 @@ cc_library(
srcs = ["syntaxnet_link_feature_extractor.cc"],
hdrs = ["syntaxnet_link_feature_extractor.h"],
deps = [
"//dragnn/protos:spec_proto",
"//dragnn/protos:spec_proto_cc",
"//syntaxnet:base",
"//syntaxnet:embedding_feature_extractor",
"//syntaxnet:parser_transitions",
......@@ -53,7 +55,7 @@ cc_library(
"//dragnn/core/interfaces:cloneable_transition_state",
"//dragnn/core/interfaces:transition_state",
"//dragnn/io:syntaxnet_sentence",
"//dragnn/protos:trace_proto",
"//dragnn/protos:trace_proto_cc",
"//syntaxnet:base",
"//syntaxnet:parser_transitions",
],
......@@ -77,7 +79,7 @@ cc_test(
"//dragnn/core/test:mock_transition_state",
"//dragnn/io:sentence_input_batch",
"//syntaxnet:base",
"//syntaxnet:sentence_proto",
"//syntaxnet:sentence_proto_cc",
"//syntaxnet:test_main",
],
)
......@@ -88,7 +90,7 @@ cc_test(
deps = [
":syntaxnet_link_feature_extractor",
"//dragnn/core/test:generic",
"//dragnn/protos:spec_proto",
"//dragnn/protos:spec_proto_cc",
"//syntaxnet:task_context",
"//syntaxnet:test_main",
],
......@@ -105,9 +107,9 @@ cc_test(
"//dragnn/core/test:generic",
"//dragnn/core/test:mock_transition_state",
"//dragnn/io:sentence_input_batch",
"//dragnn/protos:spec_proto",
"//dragnn/protos:spec_proto_cc",
"//syntaxnet:base",
"//syntaxnet:sentence_proto",
"//syntaxnet:sentence_proto_cc",
"//syntaxnet:test_main",
],
)
......@@ -22,6 +22,7 @@
#include "dragnn/core/input_batch_cache.h"
#include "dragnn/core/interfaces/component.h"
#include "dragnn/core/interfaces/transition_state.h"
#include "dragnn/core/util/label.h"
#include "dragnn/io/sentence_input_batch.h"
#include "dragnn/io/syntaxnet_sentence.h"
#include "syntaxnet/parser_state.h"
......@@ -29,13 +30,12 @@
#include "syntaxnet/task_spec.pb.h"
#include "syntaxnet/utils.h"
#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/logging.h"
#include "util/utf8/unicodetext.h"
namespace syntaxnet {
namespace dragnn {
using tensorflow::strings::StrCat;
namespace {
// Returns a new step in a trace based on a ComponentSpec.
......@@ -103,7 +103,7 @@ void SyntaxNetComponent::InitializeComponent(const ComponentSpec &spec) {
names.push_back(channel.name());
fml.push_back(channel.fml());
predicate_maps.push_back(channel.predicate_map());
dims.push_back(StrCat(channel.embedding_dim()));
dims.push_back(tensorflow::strings::StrCat(channel.embedding_dim()));
}
......@@ -125,7 +125,7 @@ void SyntaxNetComponent::InitializeComponent(const ComponentSpec &spec) {
for (const LinkedFeatureChannel &channel : spec.linked_feature()) {
names.push_back(channel.name());
fml.push_back(channel.fml());
dims.push_back(StrCat(channel.embedding_dim()));
dims.push_back(tensorflow::strings::StrCat(channel.embedding_dim()));
source_components.push_back(channel.source_component());
source_layers.push_back(channel.source_layer());
source_translators.push_back(channel.source_translator());
......@@ -332,6 +332,22 @@ std::function<int(int, int, int)> SyntaxNetComponent::GetStepLookupFunction(
return -1;
}
};
} else if (method == "reverse-char") {
// Reverses the character-level index.
return [this](int batch_index, int beam_index, int value) {
SyntaxNetTransitionState *state =
batch_.at(batch_index)->beam_state(beam_index);
const auto *sentence = state->sentence()->sentence();
const string &text = sentence->text();
const int start_byte = sentence->token(0).start();
const int end_byte = sentence->token(sentence->token_size() - 1).end();
UnicodeText unicode;
unicode.PointToUTF8(text.data() + start_byte, end_byte - start_byte + 1);
const int num_chars = distance(unicode.begin(), unicode.end());
const int result = num_chars - value - 1;
if (result >= 0 && result < num_chars) return result;
return -1;
};
} else {
LOG(FATAL) << "Unable to find step lookup function " << method;
}
......@@ -418,12 +434,12 @@ int SyntaxNetComponent::GetFixedFeatures(
const bool has_weights = f.weight_size() != 0;
for (int i = 0; i < f.description_size(); ++i) {
if (has_weights) {
fixed_features.add_value_name(StrCat("id: ", f.id(i),
" name: ", f.description(i),
fixed_features.add_value_name(tensorflow::strings::StrCat(
"id: ", f.id(i), " name: ", f.description(i),
" weight: ", f.weight(i)));
} else {
fixed_features.add_value_name(
StrCat("id: ", f.id(i), " name: ", f.description(i)));
fixed_features.add_value_name(tensorflow::strings::StrCat(
"id: ", f.id(i), " name: ", f.description(i)));
}
}
fixed_features.set_feature_name("");
......@@ -615,16 +631,19 @@ std::vector<LinkFeatures> SyntaxNetComponent::GetRawLinkFeatures(
return features;
}
std::vector<std::vector<int>> SyntaxNetComponent::GetOracleLabels() const {
std::vector<std::vector<int>> oracle_labels;
for (const auto &beam : batch_) {
oracle_labels.emplace_back();
std::vector<std::vector<std::vector<Label>>>
SyntaxNetComponent::GetOracleLabels() const {
std::vector<std::vector<std::vector<Label>>> oracle_labels(batch_.size());
for (int batch_idx = 0; batch_idx < batch_.size(); ++batch_idx) {
const auto &beam = batch_[batch_idx];
std::vector<std::vector<Label>> &output_beam = oracle_labels[batch_idx];
for (int beam_idx = 0; beam_idx < beam->size(); ++beam_idx) {
// Get the raw link features from the linked feature extractor.
auto state = beam->beam_state(beam_idx);
// Arbitrarily choose the first vector element.
oracle_labels.back().push_back(GetOracleVector(state).front());
output_beam.emplace_back();
output_beam.back().emplace_back(GetOracleVector(state).front());
}
}
return oracle_labels;
......
......@@ -25,6 +25,7 @@
#include "dragnn/core/input_batch_cache.h"
#include "dragnn/core/interfaces/component.h"
#include "dragnn/core/interfaces/transition_state.h"
#include "dragnn/core/util/label.h"
#include "dragnn/protos/data.pb.h"
#include "dragnn/protos/spec.pb.h"
#include "dragnn/protos/trace.pb.h"
......@@ -113,13 +114,24 @@ class SyntaxNetComponent : public Component {
LOG(FATAL) << "Method not supported";
}
// Bulk dense fixed-feature embedding is not supported by this
// component; any call aborts via LOG(FATAL).
void BulkEmbedDenseFixedFeatures(
const vector<const float *> &per_channel_embeddings,
float *embedding_output, int embedding_output_size,
int32 *offset_array_output, int offset_array_size) override {
LOG(FATAL) << "Method not supported";
}
// Dense feature sizing is not supported by this component; any call
// aborts via LOG(FATAL).
int BulkDenseFeatureSize() const override {
LOG(FATAL) << "Method not supported";
}
// Extracts and returns the vector of LinkFeatures for the specified
// channel. Note: these are NOT translated.
std::vector<LinkFeatures> GetRawLinkFeatures(int channel_id) const override;
// Returns a vector of oracle labels for each element in the beam and
// batch.
std::vector<std::vector<int>> GetOracleLabels() const override;
std::vector<std::vector<std::vector<Label>>> GetOracleLabels() const override;
// Annotate the underlying data object with the results of this Component's
// calculation.
......
......@@ -40,6 +40,7 @@ namespace dragnn {
namespace {
const char kSentence0[] = R"(
text: "Sentence 0."
token {
word: "Sentence" start: 0 end: 7 tag: "NN" category: "NOUN" label: "ROOT"
break_level: NO_BREAK
......@@ -55,6 +56,7 @@ token {
)";
const char kSentence1[] = R"(
text: "Sentence 1."
token {
word: "Sentence" start: 0 end: 7 tag: "NN" category: "NOUN" label: "ROOT"
break_level: NO_BREAK
......@@ -70,6 +72,7 @@ token {
)";
const char kLongSentence[] = R"(
text: "Sentence 123."
token {
word: "Sentence" start: 0 end: 7 tag: "NN" category: "NOUN" label: "ROOT"
break_level: NO_BREAK
......@@ -1310,5 +1313,30 @@ TEST_F(SyntaxNetComponentTest, BulkEmbedFixedFeaturesIsNotSupported) {
"Method not supported");
}
TEST_F(SyntaxNetComponentTest, GetStepLookupFunction) {
  // Build and serialize one small test sentence for the parser input.
  Sentence sentence_0;
  TextFormat::ParseFromString(kSentence0, &sentence_0);
  string serialized;
  sentence_0.SerializeToString(&serialized);

  constexpr int kBeamSize = 1;
  auto test_parser = CreateParserWithBeamSize(kBeamSize, {}, {serialized});
  ASSERT_TRUE(test_parser->IsReady());

  // "reverse-token" must map token index (n - i - 1) back to i.
  const auto reverse_token_lookup =
      test_parser->GetStepLookupFunction("reverse-token");
  const int kNumTokens = sentence_0.token_size();
  for (int expected = 0; expected < kNumTokens; ++expected) {
    EXPECT_EQ(expected, reverse_token_lookup(0, 0, kNumTokens - expected - 1));
  }

  // "reverse-char" must do the same at character granularity.
  const auto reverse_char_lookup =
      test_parser->GetStepLookupFunction("reverse-char");
  const int kNumChars = sentence_0.text().size();  // assumes ASCII
  for (int expected = 0; expected < kNumChars; ++expected) {
    EXPECT_EQ(expected, reverse_char_lookup(0, 0, kNumChars - expected - 1));
  }
}
} // namespace dragnn
} // namespace syntaxnet
......@@ -2,8 +2,9 @@ py_binary(
name = "make_parser_spec",
srcs = ["make_parser_spec.py"],
deps = [
"//dragnn/protos:spec_py_pb2",
"//dragnn/protos:spec_pb2_py",
"//dragnn/python:spec_builder",
"@absl_py//absl/flags",
"@org_tensorflow//tensorflow:tensorflow_py",
],
)
......@@ -14,6 +14,7 @@
# ==============================================================================
"""Construct the spec for the CONLL2017 Parser baseline."""
from absl import flags
import tensorflow as tf
from tensorflow.python.platform import gfile
......@@ -21,7 +22,6 @@ from tensorflow.python.platform import gfile
from dragnn.protos import spec_pb2
from dragnn.python import spec_builder
flags = tf.app.flags
FLAGS = flags.FLAGS
flags.DEFINE_string('spec_file', 'parser_spec.textproto',
......
......@@ -37,8 +37,9 @@ cc_library(
":input_batch_cache",
"//dragnn/components/util:bulk_feature_extractor",
"//dragnn/core/interfaces:component",
"//dragnn/protos:spec_proto",
"//dragnn/protos:trace_proto",
"//dragnn/core/util:label",
"//dragnn/protos:spec_proto_cc",
"//dragnn/protos:trace_proto_cc",
],
)
......@@ -51,9 +52,10 @@ cc_library(
":index_translator",
":input_batch_cache",
"//dragnn/components/util:bulk_feature_extractor",
"//dragnn/protos:data_proto",
"//dragnn/protos:spec_proto",
"//dragnn/protos:trace_proto",
"//dragnn/core/util:label",
"//dragnn/protos:data_proto_cc",
"//dragnn/protos:spec_proto_cc",
"//dragnn/protos:trace_proto_cc",
"//syntaxnet:base",
"//syntaxnet:registry",
],
......@@ -67,7 +69,7 @@ cc_library(
":component_registry",
":compute_session",
":compute_session_impl",
"//dragnn/protos:spec_proto",
"//dragnn/protos:spec_proto_cc",
"//syntaxnet:base",
],
)
......@@ -125,10 +127,13 @@ cc_test(
"//dragnn/components/util:bulk_feature_extractor",
"//dragnn/core/interfaces:component",
"//dragnn/core/interfaces:input_batch",
"//dragnn/core/test:fake_component_base",
"//dragnn/core/test:generic",
"//dragnn/core/test:mock_component",
"//dragnn/core/test:mock_transition_state",
"//dragnn/core/util:label",
"//syntaxnet:base",
"@org_tensorflow//tensorflow/core:test",
],
)
......@@ -182,14 +187,24 @@ cc_test(
# Tensorflow op kernel BUILD rules.
load(
"//dragnn:tensorflow_ops.bzl",
"@org_tensorflow//tensorflow:tensorflow.bzl",
"tf_gen_op_libs",
"tf_gen_op_wrapper_py",
"tf_kernel_library",
)
cc_library(
name = "shape_helpers",
hdrs = ["ops/shape_helpers.h"],
deps = [
"//syntaxnet:shape_helpers",
"@org_tensorflow//tensorflow/core:framework_headers_lib",
],
)
tf_gen_op_libs(
op_lib_names = ["dragnn_ops"],
deps = [":shape_helpers"],
)
tf_gen_op_wrapper_py(
......@@ -199,6 +214,7 @@ tf_gen_op_wrapper_py(
tf_gen_op_libs(
op_lib_names = ["dragnn_bulk_ops"],
deps = [":shape_helpers"],
)
tf_gen_op_wrapper_py(
......@@ -231,8 +247,10 @@ cc_library(
":compute_session_op",
":compute_session_pool",
":resource_container",
"//dragnn/protos:data_proto",
"//dragnn/protos:spec_proto",
":shape_helpers",
"//dragnn/core/util:label",
"//dragnn/protos:data_proto_cc",
"//dragnn/protos:spec_proto_cc",
"//syntaxnet:base",
"@org_tensorflow//third_party/eigen3",
],
......@@ -248,6 +266,8 @@ cc_library(
deps = [
":compute_session_op",
":resource_container",
":shape_helpers",
"//dragnn/core/util:label",
"//syntaxnet:base",
"@org_tensorflow//third_party/eigen3",
],
......@@ -269,8 +289,10 @@ tf_kernel_library(
":compute_session_op",
":compute_session_pool",
":resource_container",
"//dragnn/protos:data_proto",
"//dragnn/protos:spec_proto",
":shape_helpers",
"//dragnn/core/util:label",
"//dragnn/protos:data_proto_cc",
"//dragnn/protos:spec_proto_cc",
"//syntaxnet:base",
"@org_tensorflow//third_party/eigen3",
],
......@@ -289,8 +311,10 @@ tf_kernel_library(
":compute_session_op",
":compute_session_pool",
":resource_container",
":shape_helpers",
"//dragnn/components/util:bulk_feature_extractor",
"//dragnn/protos:spec_proto",
"//dragnn/core/util:label",
"//dragnn/protos:spec_proto_cc",
"//syntaxnet:base",
"@org_tensorflow//tensorflow/core:protos_all_cc",
"@org_tensorflow//third_party/eigen3",
......@@ -309,6 +333,7 @@ cc_test(
":resource_container",
"//dragnn/core/test:generic",
"//dragnn/core/test:mock_compute_session",
"//dragnn/core/util:label",
"//syntaxnet:base",
"//syntaxnet:test_main",
"@org_tensorflow//tensorflow/core:protos_all_cc",
......@@ -327,6 +352,7 @@ cc_test(
":resource_container",
"//dragnn/components/util:bulk_feature_extractor",
"//dragnn/core/test:mock_compute_session",
"//dragnn/core/util:label",
"//syntaxnet:base",
"//syntaxnet:test_main",
"@org_tensorflow//tensorflow/core/kernels:ops_testutil",
......
......@@ -22,6 +22,7 @@
#include "dragnn/core/index_translator.h"
#include "dragnn/core/input_batch_cache.h"
#include "dragnn/core/interfaces/component.h"
#include "dragnn/core/util/label.h"
#include "dragnn/protos/spec.pb.h"
#include "dragnn/protos/trace.pb.h"
......@@ -102,7 +103,7 @@ class ComputeSession {
const string &component_name, int channel_id) = 0;
// Get the oracle labels for the given component.
virtual std::vector<std::vector<int>> EmitOracleLabels(
virtual std::vector<std::vector<std::vector<Label>>> EmitOracleLabels(
const string &component_name) = 0;
// Returns true if the given component is terminal.
......@@ -126,6 +127,9 @@ class ComputeSession {
// bypassing de-serialization.
virtual void SetInputBatchCache(std::unique_ptr<InputBatchCache> batch) = 0;
// Returns the current InputBatchCache, or null if there is none.
virtual InputBatchCache *GetInputBatchCache() = 0;
// Resets all components owned by this ComputeSession.
virtual void ResetSession() = 0;
......
......@@ -18,6 +18,7 @@
#include <algorithm>
#include <utility>
#include "dragnn/core/util/label.h"
#include "dragnn/protos/data.pb.h"
#include "dragnn/protos/spec.pb.h"
#include "dragnn/protos/trace.pb.h"
......@@ -123,8 +124,12 @@ void ComputeSessionImpl::InitializeComponentData(const string &component_name,
VLOG(1) << "Source result found. Using prior initialization vector for "
<< component_name;
auto source = source_result->second;
CHECK(source->IsTerminal()) << "Source is not terminal for component '"
<< component_name << "'. Exiting.";
CHECK(source->IsTerminal())
<< "Source component '" << source->Name()
<< "' for currently active component '" << component_name
<< "' is not terminal. "
<< "Are you using bulk feature extraction with only linked features? "
<< "If so, consider using the StatelessComponent instead. Exiting.";
component->InitializeData(source->GetBeam(), max_beam_size,
input_data_.get());
}
......@@ -219,8 +224,8 @@ std::vector<LinkFeatures> ComputeSessionImpl::GetTranslatedLinkFeatures(
return features;
}
std::vector<std::vector<int>> ComputeSessionImpl::EmitOracleLabels(
const string &component_name) {
std::vector<std::vector<std::vector<Label>>>
ComputeSessionImpl::EmitOracleLabels(const string &component_name) {
return GetReadiedComponent(component_name)->GetOracleLabels();
}
......@@ -303,6 +308,10 @@ void ComputeSessionImpl::SetInputBatchCache(
input_data_ = std::move(batch);
}
// Returns the current InputBatchCache, or null if none has been set.
// The session retains ownership; callers must not delete the pointer.
InputBatchCache *ComputeSessionImpl::GetInputBatchCache() {
return input_data_.get();
}
void ComputeSessionImpl::ResetSession() {
// Reset all component states.
for (auto &component_pair : components_) {
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment