Unverified Commit 80178fc6 authored by Mark Omernick's avatar Mark Omernick Committed by GitHub
Browse files

Merge pull request #4153 from terryykoo/master

Export @195097388.
parents a84e1ef9 edea2b67
......@@ -16,20 +16,23 @@
#ifndef DRAGNN_CORE_COMPUTE_SESSION_IMPL_H_
#define DRAGNN_CORE_COMPUTE_SESSION_IMPL_H_
#include <map>
#include <memory>
#include "dragnn/components/util/bulk_feature_extractor.h"
#include "dragnn/core/compute_session.h"
#include "dragnn/core/index_translator.h"
#include "dragnn/core/input_batch_cache.h"
#include "dragnn/core/util/label.h"
#include "dragnn/protos/data.pb.h"
#include "dragnn/protos/spec.pb.h"
#include "dragnn/protos/trace.pb.h"
namespace syntaxnet {
namespace dragnn {
class ComputeSessionImpl : public ComputeSession {
class ComputeSessionImpl final : public ComputeSession {
public:
// Creates a ComputeSessionImpl with the provided component builder function.
ComputeSessionImpl(
......@@ -77,7 +80,7 @@ class ComputeSessionImpl : public ComputeSession {
std::vector<LinkFeatures> GetTranslatedLinkFeatures(
const string &component_name, int channel_id) override;
std::vector<std::vector<int>> EmitOracleLabels(
std::vector<std::vector<std::vector<Label>>> EmitOracleLabels(
const string &component_name) override;
bool IsTerminal(const string &component_name) override;
......@@ -92,6 +95,8 @@ class ComputeSessionImpl : public ComputeSession {
void SetInputBatchCache(std::unique_ptr<InputBatchCache> batch) override;
InputBatchCache *GetInputBatchCache() override;
void ResetSession() override;
void SetTracing(bool tracing_on) override;
......@@ -108,6 +113,11 @@ class ComputeSessionImpl : public ComputeSession {
Component *GetReadiedComponent(const string &component_name) const override;
private:
// Mapping from Keys to Values.
template <class Key, class Value>
using Mapping = std::map<Key, Value>;
// Get a given component. Fails if the component is not found.
Component *GetComponent(const string &component_name) const;
......@@ -124,11 +134,11 @@ class ComputeSessionImpl : public ComputeSession {
// Holds all of the components owned by this ComputeSession, associated with
// their names in the MasterSpec.
std::map<string, std::unique_ptr<Component>> components_;
Mapping<string, std::unique_ptr<Component>> components_;
// Holds a vector of translators for each component, indexed by the name
// of the component they belong to.
std::map<string, std::vector<IndexTranslator *>> translators_;
Mapping<string, std::vector<IndexTranslator *>> translators_;
// Holds ownership of all the IndexTranslators for this compute session.
std::vector<std::unique_ptr<IndexTranslator>> owned_translators_;
......@@ -136,7 +146,7 @@ class ComputeSessionImpl : public ComputeSession {
// The predecessor component for every component.
// If a component is not in this map, it has no predecessor component and
// will have its beam initialized without any data from other components.
std::map<Component *, Component *> predecessors_;
Mapping<Component *, Component *> predecessors_;
// Holds the current input data for this ComputeSession.
std::unique_ptr<InputBatchCache> input_data_;
......
......@@ -25,240 +25,49 @@
#include "dragnn/core/input_batch_cache.h"
#include "dragnn/core/interfaces/component.h"
#include "dragnn/core/interfaces/input_batch.h"
#include "dragnn/core/test/fake_component_base.h"
#include "dragnn/core/test/generic.h"
#include "dragnn/core/test/mock_component.h"
#include "dragnn/core/test/mock_transition_state.h"
#include "dragnn/core/util/label.h"
#include "tensorflow/core/platform/test.h"
namespace syntaxnet {
namespace dragnn {
using syntaxnet::test::EqualsProto;
using testing::_;
using testing::ElementsAre;
using testing::Return;
using testing::NotNull;
using testing::Return;
using testing::_;
// *****************************************************************************
// Test-internal class definitions.
// *****************************************************************************
// Define a test component to validate registered construction.
class TestComponentType1 : public Component {
class TestComponentType1 : public FakeComponentBase {
public:
TestComponentType1() {}
void InitializeComponent(const ComponentSpec &spec) override {
name_ = spec.name();
}
void InitializeData(
const std::vector<std::vector<const TransitionState *>> &states,
int max_beam_size, InputBatchCache *input_data) override {}
void InitializeTracing() override {}
void DisableTracing() override {}
bool IsReady() const override { return true; }
string Name() const override { return name_; }
int BeamSize() const override { return 3; }
int BatchSize() const override { return 1; }
int StepsTaken(int batch_index) const override { return 0; }
int GetBeamIndexAtStep(int step, int current_index,
int batch) const override {
return 0;
}
int GetSourceBeamIndex(int current_index, int batch) const override {
return 0;
}
bool AdvanceFromPrediction(const float *score_matrix, int num_items,
int num_actions) override {
return true;
}
void AdvanceFromOracle() override {}
bool IsTerminal() const override { return true; }
std::function<int(int, int, int)> GetStepLookupFunction(
const string &method) override {
return nullptr;
}
std::vector<std::vector<const TransitionState *>> GetBeam() override {
std::vector<std::vector<const TransitionState *>> states;
return states;
}
int GetFixedFeatures(std::function<int32 *(int)> allocate_indices,
std::function<int64 *(int)> allocate_ids,
std::function<float *(int)> allocate_weights,
int channel_id) const override {
return 0;
}
void BulkEmbedFixedFeatures(
int batch_size_padding, int num_steps_padding, int embedding_size,
const vector<const float *> &per_channel_embeddings,
float *embedding_output) override {}
int BulkGetFixedFeatures(const BulkFeatureExtractor &extractor) override {
return 0;
}
std::vector<LinkFeatures> GetRawLinkFeatures(int channel_id) const override {
std::vector<LinkFeatures> ret;
return ret;
}
std::vector<std::vector<int>> GetOracleLabels() const override {
std::vector<std::vector<int>> ret;
return ret;
}
void FinalizeData() override {}
void ResetComponent() override {}
std::vector<std::vector<ComponentTrace>> GetTraceProtos() const override {
std::vector<std::vector<ComponentTrace>> ret;
return ret;
}
void AddTranslatedLinkFeaturesToTrace(
const std::vector<LinkFeatures> &features, int channel_id) override {}
string name_;
};
REGISTER_DRAGNN_COMPONENT(TestComponentType1);
// Define a second test component to validate registered construction.
class TestComponentType2 : public Component {
class TestComponentType2 : public FakeComponentBase {
public:
TestComponentType2() {}
void InitializeComponent(const ComponentSpec &spec) override {
name_ = spec.name();
}
void InitializeData(
const std::vector<std::vector<const TransitionState *>> &states,
int max_beam_size, InputBatchCache *input_data) override {}
void InitializeTracing() override {}
void DisableTracing() override {}
bool IsReady() const override { return true; }
string Name() const override { return name_; }
int BeamSize() const override { return 4; }
int BatchSize() const override { return 2; }
int StepsTaken(int batch_index) const override { return 0; }
int GetBeamIndexAtStep(int step, int current_index,
int batch) const override {
return 0;
}
int GetSourceBeamIndex(int current_index, int batch) const override {
return 0;
}
bool AdvanceFromPrediction(const float *score_matrix, int num_items,
int num_actions) override {
return true;
}
void AdvanceFromOracle() override {}
bool IsTerminal() const override { return true; }
std::function<int(int, int, int)> GetStepLookupFunction(
const string &method) override {
return nullptr;
}
std::vector<std::vector<const TransitionState *>> GetBeam() override {
std::vector<std::vector<const TransitionState *>> states;
return states;
}
int GetFixedFeatures(std::function<int32 *(int)> allocate_indices,
std::function<int64 *(int)> allocate_ids,
std::function<float *(int)> allocate_weights,
int channel_id) const override {
return 0;
}
void BulkEmbedFixedFeatures(
int batch_size_padding, int num_steps_padding, int embedding_size,
const vector<const float *> &per_channel_embeddings,
float *embedding_output) override {}
int BulkGetFixedFeatures(const BulkFeatureExtractor &extractor) override {
return 0;
}
std::vector<LinkFeatures> GetRawLinkFeatures(int channel_id) const override {
std::vector<LinkFeatures> ret;
return ret;
}
std::vector<std::vector<int>> GetOracleLabels() const override {
std::vector<std::vector<int>> ret;
return ret;
}
void FinalizeData() override {}
void ResetComponent() override {}
std::vector<std::vector<ComponentTrace>> GetTraceProtos() const override {
std::vector<std::vector<ComponentTrace>> ret;
return ret;
}
void AddTranslatedLinkFeaturesToTrace(
const std::vector<LinkFeatures> &features, int channel_id) override {}
string name_;
};
REGISTER_DRAGNN_COMPONENT(TestComponentType2);
// Define a component that returns false for IsReady and IsTerminal.
class UnreadyComponent : public Component {
class UnreadyComponent : public FakeComponentBase {
public:
UnreadyComponent() {}
void InitializeComponent(const ComponentSpec &spec) override {
name_ = spec.name();
}
void InitializeData(
const std::vector<std::vector<const TransitionState *>> &states,
int max_beam_size, InputBatchCache *input_data) override {}
void InitializeTracing() override {}
void DisableTracing() override {}
bool IsReady() const override { return false; }
string Name() const override { return name_; }
int BeamSize() const override { return 1; }
int BatchSize() const override { return 2; }
int StepsTaken(int batch_index) const override { return 0; }
int GetBeamIndexAtStep(int step, int current_index,
int batch) const override {
return 0;
}
int GetSourceBeamIndex(int current_index, int batch) const override {
return 0;
}
bool AdvanceFromPrediction(const float *score_matrix, int num_items,
int num_actions) override {
return true;
}
void BulkEmbedFixedFeatures(
int batch_size_padding, int num_steps_padding, int embedding_size,
const vector<const float *> &per_channel_embeddings,
float *embedding_output) override {}
void AdvanceFromOracle() override {}
bool IsTerminal() const override { return false; }
std::function<int(int, int, int)> GetStepLookupFunction(
const string &method) override {
return nullptr;
}
std::vector<std::vector<const TransitionState *>> GetBeam() override {
std::vector<std::vector<const TransitionState *>> states;
return states;
}
int GetFixedFeatures(std::function<int32 *(int)> allocate_indices,
std::function<int64 *(int)> allocate_ids,
std::function<float *(int)> allocate_weights,
int channel_id) const override {
return 0;
}
int BulkGetFixedFeatures(const BulkFeatureExtractor &extractor) override {
return 0;
}
std::vector<LinkFeatures> GetRawLinkFeatures(int channel_id) const override {
std::vector<LinkFeatures> ret;
return ret;
}
std::vector<std::vector<int>> GetOracleLabels() const override {
std::vector<std::vector<int>> ret;
return ret;
}
void FinalizeData() override {}
void ResetComponent() override {}
std::vector<std::vector<ComponentTrace>> GetTraceProtos() const override {
std::vector<std::vector<ComponentTrace>> ret;
return ret;
}
void AddTranslatedLinkFeaturesToTrace(
const std::vector<LinkFeatures> &features, int channel_id) override {}
string name_;
};
REGISTER_DRAGNN_COMPONENT(UnreadyComponent);
......@@ -850,7 +659,7 @@ TEST(ComputeSessionImplTest,
// The death expectation is interacting strangely with this test, so I need
// to wrap the function in a lambda.
EXPECT_DEATH(function_that_will_die(), "Source is not terminal");
EXPECT_DEATH(function_that_will_die(), "is not terminal");
}
TEST(ComputeSessionImplTest, ResetSessionResetsAllComponents) {
......@@ -1147,7 +956,10 @@ TEST(ComputeSessionImplTest, InterfacePassesThrough) {
session->BulkEmbedFixedFeatures("component_one", 1, 2, 3, {nullptr}, nullptr);
// EmitOracleLabels()
std::vector<std::vector<int>> oracle_labels = {{0, 1}, {2, 3}};
// The size of oracle_labels is batch_size * beam_size * num_labels.
const std::vector<std::vector<std::vector<Label>>> oracle_labels{
{{{0, 1.f}}, {{1, 1.f}}}, {{{2, 1.f}}, {{3, 1.f}}}};
EXPECT_CALL(*mock_components["component_one"], GetOracleLabels())
.WillOnce(Return(oracle_labels));
EXPECT_EQ(oracle_labels, session->EmitOracleLabels("component_one"));
......@@ -1227,5 +1039,29 @@ TEST(ComputeSessionImplTest, SetInputBatchCache) {
EXPECT_EQ(session->GetSerializedPredictions(), data);
}
TEST(ComputeSessionImplTest, GetInputBatchCache) {
// Use empty protos since we won't interact with components.
MasterSpec spec;
GridPoint hyperparams;
ComputeSessionPool pool(spec, hyperparams);
auto session = pool.GetSession();
// No input data yet.
EXPECT_EQ(session->GetInputBatchCache(), nullptr);
// Set some data, expect some batch to be returned.
session->SetInputData({"arbitrary_data"});
EXPECT_NE(session->GetInputBatchCache(), nullptr);
// Create a dummy batch.
const std::vector<string> data = {"foo", "bar", "baz"};
std::unique_ptr<InputBatchCache> input_batch_cache(new InputBatchCache(data));
InputBatchCache *input_batch_cache_ptr = input_batch_cache.get();
// Inject a batch, expect that batch to be returned.
session->SetInputBatchCache(std::move(input_batch_cache));
EXPECT_EQ(session->GetInputBatchCache(), input_batch_cache_ptr);
}
} // namespace dragnn
} // namespace syntaxnet
......@@ -33,9 +33,9 @@ ComputeSessionPool::ComputeSessionPool(const MasterSpec &master_spec,
num_unique_sessions_(0) {
// Create a default component builder function. This function looks up
// components in the component registry and returns them.
component_builder_ = [](
const string &component_name,
const string &backend_type) -> std::unique_ptr<Component> {
component_builder_ =
[](const string &component_name,
const string &backend_type) -> std::unique_ptr<Component> {
VLOG(2) << "Creating component " << component_name << " with backend "
<< backend_type;
std::unique_ptr<Component> component(Component::Create(backend_type));
......@@ -45,7 +45,7 @@ ComputeSessionPool::ComputeSessionPool(const MasterSpec &master_spec,
// Create a default session builder function. This function returns a
// ComputeSessionImpl that uses the currently set component_builder_
// function to create its components.
session_builder_ = [this]() {
session_builder_ = [this]() EXCLUSIVE_LOCKS_REQUIRED(lock_) {
return std::unique_ptr<ComputeSession>(
new ComputeSessionImpl(num_unique_sessions_, this->component_builder_));
};
......@@ -75,20 +75,28 @@ void ComputeSessionPool::SetComponentBuilder(
}
std::unique_ptr<ComputeSession> ComputeSessionPool::GetSession() {
mutex_lock lock(lock_);
std::unique_ptr<ComputeSession> session_ptr;
if (sessions_.empty()) {
// There are no available sessions, so create and initialize one.
bool is_new = false;
{
// This mutex effectively single-threads the application at this point,
// since all ComputeSessions must call here; to minimize impact, we
// subscope it.
mutex_lock lock(lock_);
if (!sessions_.empty()) {
VLOG(2) << "Reusing session from pool of size " << sessions_.size();
session_ptr = std::move(sessions_.back());
sessions_.pop_back();
} else {
session_ptr = session_builder_();
is_new = true;
num_unique_sessions_++;
}
}
if (is_new) {
VLOG(2) << "Creating new session.";
session_ptr = session_builder_();
num_unique_sessions_++;
session_ptr->Init(master_spec_, hyperparams_);
} else {
// Get the last free session, and remove it from the free sessions vector.
VLOG(2) << "Reusing session from pool of size " << sessions_.size();
session_ptr = std::move(sessions_.back());
sessions_.pop_back();
session_ptr->ResetSession();
}
return session_ptr;
......
......@@ -21,6 +21,7 @@
#include "dragnn/core/compute_session.h"
#include "dragnn/protos/spec.pb.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/thread_annotations.h"
namespace syntaxnet {
namespace dragnn {
......@@ -50,7 +51,10 @@ class ComputeSessionPool {
}
// Returns the number of unique sessions that have been created.
int num_unique_sessions() { return num_unique_sessions_; }
int num_unique_sessions() {
tensorflow::mutex_lock lock(lock_);
return num_unique_sessions_;
}
// Returns a reference to the underlying spec for this pool.
const MasterSpec &GetSpec() const { return master_spec_; }
......@@ -82,21 +86,22 @@ class ComputeSessionPool {
const GridPoint hyperparams_;
// The function that is used to create ComputeSessions.
std::function<std::unique_ptr<ComputeSession>()> session_builder_;
std::function<std::unique_ptr<ComputeSession>()> session_builder_
GUARDED_BY(lock_);
// The function passed to ComputeSessions that will be used by that session
// to create components.
std::function<std::unique_ptr<Component>(const string &component_name,
const string &backend_type)>
component_builder_;
component_builder_ GUARDED_BY(lock_);
// ComputeSessions that are not currently being used. These sessions are not
// reset until they are requested by another thread.
std::vector<std::unique_ptr<ComputeSession>> sessions_;
std::vector<std::unique_ptr<ComputeSession>> sessions_ GUARDED_BY(lock_);
// Count of the number of unique ComputeSession objects that have been
// created. Used to assign IDs to new Sessions.
int num_unique_sessions_;
int num_unique_sessions_ GUARDED_BY(lock_);
// Mutex that protects accesses to all members of this object.
tensorflow::mutex lock_;
......
......@@ -33,7 +33,7 @@ IndexTranslator::IndexTranslator(const std::vector<Component *> &path,
} else if (method_ == "history") {
// History lookup: Return the number of steps taken less the feature.
step_lookup_ = [this](int batch_index, int beam_index, int feature) {
if (feature > path_.back()->StepsTaken(batch_index) - 1) {
if (feature > path_.back()->StepsTaken(batch_index) - 1 || feature < 0) {
VLOG(2) << "Translation to outside: feature is " << feature
<< " and steps_taken is "
<< path_.back()->StepsTaken(batch_index);
......
......@@ -16,8 +16,9 @@ cc_library(
":transition_state",
"//dragnn/components/util:bulk_feature_extractor",
"//dragnn/core:input_batch_cache",
"//dragnn/protos:spec_proto",
"//dragnn/protos:trace_proto",
"//dragnn/core/util:label",
"//dragnn/protos:spec_proto_cc",
"//dragnn/protos:trace_proto_cc",
"//syntaxnet:base",
"//syntaxnet:registry",
],
......
......@@ -21,6 +21,7 @@
#include "dragnn/components/util/bulk_feature_extractor.h"
#include "dragnn/core/input_batch_cache.h"
#include "dragnn/core/interfaces/transition_state.h"
#include "dragnn/core/util/label.h"
#include "dragnn/protos/spec.pb.h"
#include "dragnn/protos/trace.pb.h"
#include "syntaxnet/registry.h"
......@@ -120,6 +121,18 @@ class Component : public RegisterableClass<Component> {
const vector<const float *> &per_channel_embeddings,
float *embedding_output) = 0;
// Directly computes the embedding matrix for all channels, advancing the
// component via the oracle until it is terminal. This call takes a vector
// of float matrices containing embeddings, one per channel, in channel order.
// This function outputs a densified right-ragged tensor.
virtual void BulkEmbedDenseFixedFeatures(
const vector<const float *> &per_channel_embeddings,
float *embedding_output, int embedding_output_size,
int32 *offset_array_output, int offset_array_size) = 0;
// Gets the expected size of the data matrix for BulkEmbedDenseFixedFeatures.
virtual int BulkDenseFeatureSize() const = 0;
// Extracts and returns the vector of LinkFeatures for the specified
// channel. Note: these are NOT translated.
virtual std::vector<LinkFeatures> GetRawLinkFeatures(
......@@ -127,7 +140,8 @@ class Component : public RegisterableClass<Component> {
// Returns a vector of oracle labels for each element in the beam and
// batch.
virtual std::vector<std::vector<int>> GetOracleLabels() const = 0;
virtual std::vector<std::vector<std::vector<Label>>> GetOracleLabels()
const = 0;
// Annotate the underlying data object with the results of this Component's
// calculation.
......
......@@ -29,8 +29,8 @@ namespace dragnn {
// another, and every backend should define one. Note that inheriting from
// TransitionState directly is not sufficient to use the Beam class, which
// requires extra functionality given by inheriting from the
// ClonableTransitionState interface. (ClonableTransitionState is a subclass
// of TransitionState, so inheriting from ClonableTransitionState is sufficient
// CloneableTransitionState interface. (CloneableTransitionState is a subclass
// of TransitionState, so inheriting from CloneableTransitionState is sufficient
// to allow Components to pass your backing states.)
class TransitionState {
......
......@@ -62,6 +62,10 @@ void ComputeSessionOp::Compute(OpKernelContext *context) {
"Must declare at least one output of type string "
"for the ComputeSession handle if OutputsHandle is true."));
}
OP_REQUIRES(
context, context->input(0).dims() == 1,
InvalidArgument("Input to ComputeSession must be a vector, got rank ",
context->input(0).dims()));
// Gets the relevant ComputeSessionResource and computes with it.
auto handle = context->input(0).vec<string>();
......
......@@ -20,6 +20,7 @@
#include "dragnn/core/ops/compute_session_op.h"
#include "dragnn/core/resource_container.h"
#include "dragnn/core/util/label.h"
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/framework/numeric_types.h"
#include "tensorflow/core/framework/op_kernel.h"
......@@ -40,10 +41,10 @@ using tensorflow::DataType;
using tensorflow::OpKernel;
using tensorflow::OpKernelConstruction;
using tensorflow::OpKernelContext;
using tensorflow::quint8;
using tensorflow::Status;
using tensorflow::Tensor;
using tensorflow::TensorShape;
using tensorflow::quint8;
using tensorflow::uint8;
namespace syntaxnet {
......@@ -335,11 +336,19 @@ class BulkEmbedFixedFeatures : public ComputeSessionOp {
embeddings[channel] =
context->input(embeddings_index).flat<float>().data();
}
int batch_size;
if (pad_to_batch_ == -1) {
batch_size = session->BatchSize(component_name());
} else {
batch_size = pad_to_batch_;
}
VLOG(2) << "batch size: " << batch_size;
Tensor *embedding_vectors;
OP_REQUIRES_OK(context,
context->allocate_output(
1,
TensorShape({pad_to_steps_ * pad_to_batch_ *
TensorShape({pad_to_steps_ * batch_size *
session->BeamSize(component_name()),
embedding_size}),
&embedding_vectors));
......@@ -348,8 +357,8 @@ class BulkEmbedFixedFeatures : public ComputeSessionOp {
&num_steps_tensor));
embedding_vectors->flat<float>().setZero();
int output_size = embedding_vectors->NumElements();
session->BulkEmbedFixedFeatures(component_name(), pad_to_batch_,
pad_to_steps_, output_size, embeddings,
session->BulkEmbedFixedFeatures(component_name(), batch_size, pad_to_steps_,
output_size, embeddings,
embedding_vectors->flat<float>().data());
num_steps_tensor->scalar<int32>()() = pad_to_steps_;
}
......@@ -370,6 +379,74 @@ class BulkEmbedFixedFeatures : public ComputeSessionOp {
REGISTER_KERNEL_BUILDER(Name("BulkEmbedFixedFeatures").Device(DEVICE_CPU),
BulkEmbedFixedFeatures);
// See docstring in dragnn_bulk_ops.cc.
class BulkEmbedDenseFixedFeatures : public ComputeSessionOp {
public:
explicit BulkEmbedDenseFixedFeatures(OpKernelConstruction *context)
: ComputeSessionOp(context) {
OP_REQUIRES_OK(context, context->GetAttr("num_channels", &num_channels_));
// The input vector's zeroth element is the state handle, and the remaining
// num_channels_ elements are tensors of float embeddings, one per channel.
std::vector<DataType> input_types(num_channels_ + 1, DT_FLOAT);
input_types[0] = DT_STRING;
const std::vector<DataType> output_types = {DT_STRING, DT_FLOAT, DT_INT32};
OP_REQUIRES_OK(context, context->MatchSignature(input_types, output_types));
}
bool OutputsHandle() const override { return true; }
bool RequiresComponentName() const override { return true; }
void ComputeWithState(OpKernelContext *context,
ComputeSession *session) override {
const auto &spec = session->Spec(component_name());
int embedding_size = 0;
std::vector<const float *> embeddings(num_channels_);
for (int channel = 0; channel < num_channels_; ++channel) {
const int embeddings_index = channel + 1;
embedding_size += context->input(embeddings_index).shape().dim_size(1) *
spec.fixed_feature(channel).size();
embeddings[channel] =
context->input(embeddings_index).flat<float>().data();
}
auto component = session->GetReadiedComponent(component_name());
int data_tensor_size = component->BulkDenseFeatureSize();
Tensor *embedding_vectors;
OP_REQUIRES_OK(context,
context->allocate_output(
1, TensorShape({data_tensor_size, embedding_size}),
&embedding_vectors));
Tensor *offset_array_tensor;
OP_REQUIRES(context, component->BeamSize() == 1,
tensorflow::errors::FailedPrecondition("Beam must be 1."));
OP_REQUIRES_OK(context, context->allocate_output(
2, TensorShape({component->BatchSize() + 1}),
&offset_array_tensor));
embedding_vectors->flat<float>().setZero();
int output_size = embedding_vectors->NumElements();
int offset_array_size = offset_array_tensor->NumElements();
component->BulkEmbedDenseFixedFeatures(
embeddings, embedding_vectors->flat<float>().data(), output_size,
offset_array_tensor->flat<int32>().data(), offset_array_size);
}
private:
// Number of fixed feature channels.
int num_channels_;
// Will pad output to this many batch elements.
int pad_to_batch_;
// Will pad output to this many steps.
int pad_to_steps_;
TF_DISALLOW_COPY_AND_ASSIGN(BulkEmbedDenseFixedFeatures);
};
REGISTER_KERNEL_BUILDER(Name("BulkEmbedDenseFixedFeatures").Device(DEVICE_CPU),
BulkEmbedDenseFixedFeatures);
// See docstring in dragnn_bulk_ops.cc.
class BulkAdvanceFromOracle : public ComputeSessionOp {
public:
......@@ -388,7 +465,9 @@ class BulkAdvanceFromOracle : public ComputeSessionOp {
const int batch_size = session->BatchSize(component_name());
const int beam_size = session->BeamSize(component_name());
const int num_items = batch_size * beam_size;
vector<vector<vector<int32>>> gold;
// Nested vector of size step_count * batch_size * beam_size * label_count.
vector<vector<vector<vector<Label>>>> gold;
int num_steps = 0;
while (!session->IsTerminal(component_name())) {
......@@ -408,8 +487,12 @@ class BulkAdvanceFromOracle : public ComputeSessionOp {
for (int batch_ix = 0; batch_ix < batch_size; ++batch_ix) {
for (int beam_ix = 0; beam_ix < beam_size; ++beam_ix, ++item) {
for (int step = 0; step < num_steps; ++step) {
// The default transition system behavior is a one-hot multi-class
// prediction, so there is only one gold label. If there are more than
// one gold labels, the code assumes they are equally valid, and we
// arbitrarily pick the first one.
gold_output->vec<int32>()(item * num_steps + step) =
step < gold.size() ? gold[step][batch_ix][beam_ix] : -1;
step < gold.size() ? gold[step][batch_ix][beam_ix][0].id : -1;
}
}
}
......
......@@ -17,6 +17,7 @@
#include "dragnn/core/compute_session_pool.h"
#include "dragnn/core/resource_container.h"
#include "dragnn/core/test/mock_compute_session.h"
#include "dragnn/core/util/label.h"
#include <gmock/gmock.h>
#include "tensorflow/core/framework/fake_input.h"
......@@ -624,13 +625,16 @@ TEST_F(DragnnBulkOpKernelsTest, BulkAdvanceFromOracle) {
.WillOnce(Return(true));
EXPECT_CALL(*mock_session, AdvanceFromOracle(kComponentName))
.Times(kNumSteps);
const vector<vector<vector<int32>>> gold = {
{{1}, {1}, {1}}, {{2}, {2}, {2}}, {{3}, {3}, {3}},
const std::vector<std::vector<std::vector<std::vector<Label>>>> gold_labels{
{{{{1, 1.f}}}, {{{1, 1.f}}}, {{{1, 1.f}}}},
{{{{2, 1.f}}}, {{{2, 1.f}}}, {{{2, 1.f}}}},
{{{{3, 1.f}}}, {{{3, 1.f}}}, {{{3, 1.f}}}},
};
EXPECT_CALL(*mock_session, EmitOracleLabels(kComponentName))
.WillOnce(Return(gold[0]))
.WillOnce(Return(gold[1]))
.WillOnce(Return(gold[2]));
.WillOnce(Return(gold_labels[0]))
.WillOnce(Return(gold_labels[1]))
.WillOnce(Return(gold_labels[2]));
EXPECT_CALL(*mock_session, BeamSize(kComponentName)).WillOnce(Return(1));
EXPECT_CALL(*mock_session, BatchSize(kComponentName))
.WillOnce(Return(kNumItems));
......
......@@ -13,6 +13,7 @@
// limitations under the License.
// =============================================================================
#include "dragnn/core/ops/shape_helpers.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/shape_inference.h"
......@@ -28,6 +29,15 @@ REGISTER_OP("BulkFixedFeatures")
.Output("num_steps: int32")
.Attr("component: string")
.Attr("num_channels: int")
.SetShapeFn([](tensorflow::shape_inference::InferenceContext *context) {
int num_channels;
TF_RETURN_IF_ERROR(context->GetAttr("num_channels", &num_channels));
for (int i = 1; i <= 3 * num_channels; ++i) {
VectorOutputShape(i, context);
}
ScalarOutputShape(3 * num_channels + 1, context);
return ComputeSessionHandleInputAndOutputShape(context);
})
.Doc(R"doc(
Given a ComputeSession and a component, outputs fixed features for all steps.
......@@ -60,6 +70,16 @@ REGISTER_OP("BulkFixedEmbeddings")
.Attr("pad_to_batch: int=-1")
.Attr("pad_to_steps: int=-1")
.SetIsStateful()
.SetShapeFn([](tensorflow::shape_inference::InferenceContext *context) {
int num_channels;
TF_RETURN_IF_ERROR(context->GetAttr("num_channels", &num_channels));
for (int i = 1; i <= num_channels; ++i) {
TF_RETURN_IF_ERROR(MatrixInputShape(i, context));
}
MatrixOutputShape(1, context);
ScalarOutputShape(2, context);
return ComputeSessionHandleInputAndOutputShape(context);
})
.Doc(R"doc(
This op is a more efficient version of BulkFixedFeatures.
......@@ -91,6 +111,16 @@ REGISTER_OP("BulkEmbedFixedFeatures")
.Attr("pad_to_batch: int")
.Attr("pad_to_steps: int")
.SetIsStateful()
.SetShapeFn([](tensorflow::shape_inference::InferenceContext *context) {
int num_channels;
TF_RETURN_IF_ERROR(context->GetAttr("num_channels", &num_channels));
for (int i = 1; i <= num_channels; ++i) {
TF_RETURN_IF_ERROR(MatrixInputShape(i, context));
}
MatrixOutputShape(1, context);
ScalarOutputShape(2, context);
return ComputeSessionHandleInputAndOutputShape(context);
})
.Doc(R"doc(
This op is a more efficient version of BulkFixedFeatures.
......@@ -112,11 +142,55 @@ pad_to_batch: The op will pad/truncate to this number of elements.
pad_to_steps: The op will pad/truncate to this number of steps.
)doc");
REGISTER_OP("BulkEmbedDenseFixedFeatures")
.Input("handle: string")
.Input("embedding_matrix: num_channels * float")
.Output("output_handle: string")
.Output("embedding_vectors: float")
.Output("offset_array: int32")
.Attr("component: string")
.Attr("num_channels: int")
.SetIsStateful()
.SetShapeFn([](tensorflow::shape_inference::InferenceContext *context) {
int num_channels;
TF_RETURN_IF_ERROR(context->GetAttr("num_channels", &num_channels));
for (int i = 1; i <= num_channels; ++i) {
TF_RETURN_IF_ERROR(MatrixInputShape(i, context));
}
MatrixOutputShape(1, context);
VectorOutputShape(2, context);
return ComputeSessionHandleInputAndOutputShape(context);
})
.Doc(R"doc(
This op is a more efficient version of BulkFixedFeatures.
It is intended to be run with large batch sizes at inference time. The op takes
a handle to ComputeSession and embedding matrices as tensor inputs, and directly
outputs concatenated embedding vectors. It calls the BulkEmbedFixedFeatures
method on the underlying component directly, so it requires a padding vector
to be passed.
handle: A handle to ComputeSession.
embedding_matrix: Embedding matrices.
output_handle: A handle to the same ComputeSession after advancement.
embedding_vectors: (matrix of float) Concatenated embeddings, in a dense
array.
offset_array: An array of integers representing the offset of each batch element
in the embedding_vectors array. It is of size (batch+1) and the last element is
the total size of the embedding array.
component: The name of a Component instance, matching the ComponentSpec.name.
num_channels: The number of FixedFeature channels.
)doc");
REGISTER_OP("BulkAdvanceFromOracle")
.Input("handle: string")
.Output("output_handle: string")
.Output("gold_labels: int32")
.Attr("component: string")
.SetShapeFn([](tensorflow::shape_inference::InferenceContext *context) {
VectorOutputShape(1, context);
return ComputeSessionHandleInputAndOutputShape(context);
})
.Doc(R"doc(
Given a ComputeSession, advances until all states are final.
......@@ -140,14 +214,9 @@ REGISTER_OP("BulkAdvanceFromPrediction")
.Output("output_handle: string")
.Attr("component: string")
.Attr("T: type")
.SetShapeFn([](tensorflow::shape_inference::InferenceContext *c) {
tensorflow::shape_inference::ShapeHandle handle;
TF_RETURN_IF_ERROR(c->Merge(c->input(0), c->Vector(2), &handle));
c->set_output(0, handle);
auto scores = c->input(1);
TF_RETURN_IF_ERROR(c->WithRank(scores, 2, &scores));
return tensorflow::Status::OK();
.SetShapeFn([](tensorflow::shape_inference::InferenceContext *context) {
TF_RETURN_IF_ERROR(MatrixInputShape(1, context));
return ComputeSessionHandleInputAndOutputShape(context);
})
.Doc(R"doc(
Given a ComputeSession and a tensor of scores, advances the state.
......
......@@ -21,6 +21,7 @@
#include "dragnn/core/compute_session_pool.h"
#include "dragnn/core/ops/compute_session_op.h"
#include "dragnn/core/resource_container.h"
#include "dragnn/core/util/label.h"
#include "dragnn/protos/data.pb.h"
#include "dragnn/protos/spec.pb.h"
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
......@@ -41,8 +42,6 @@ using tensorflow::DT_INT32;
using tensorflow::DT_INT64;
using tensorflow::DT_STRING;
using tensorflow::DataType;
using tensorflow::io::Dirname;
using tensorflow::io::JoinPath;
using tensorflow::OpKernel;
using tensorflow::OpKernelConstruction;
using tensorflow::OpKernelContext;
......@@ -50,6 +49,8 @@ using tensorflow::ResourceMgr;
using tensorflow::Status;
using tensorflow::Tensor;
using tensorflow::TensorShape;
using tensorflow::io::Dirname;
using tensorflow::io::JoinPath;
namespace syntaxnet {
namespace dragnn {
......@@ -330,6 +331,209 @@ class GetSessionCounts : public OpKernel {
REGISTER_KERNEL_BUILDER(Name("GetSessionCounts").Device(DEVICE_CPU),
GetSessionCounts);
// Rebatches a dense ragged tensor into a batch of padded subsequences.
//
// Input 0 ("dense_data") is a [total_tokens, embedding_size] matrix of
// concatenated per-token embeddings; input 1 ("offsets") is a
// [batch_size + 1] vector of row offsets delimiting each passage (passage i
// spans rows offsets(i)..offsets(i+1)).
// Output 0 is a [num_subsequences, 2 * lr_padding + sequence_length,
// embedding_size] tensor of zero-padded subsequence slices; output 1 maps
// each subsequence row back to the index of the passage it was cut from.
class RebatchDensor : public OpKernel {
 public:
  // Reads the "sequence_length" and "lr_padding" attrs and validates them.
  explicit RebatchDensor(OpKernelConstruction *context) : OpKernel(context) {
    OP_REQUIRES_OK(context,
                   context->GetAttr("sequence_length", &sequence_length_));
    OP_REQUIRES_OK(context, context->GetAttr("lr_padding", &lr_padding_));
    OP_REQUIRES_OK(context, context->MatchSignature({DT_FLOAT, DT_INT32},
                                                    {DT_FLOAT, DT_INT32}));
    // Padding must be strictly shorter than the subsequence length, so every
    // subsequence carries at least one non-context token.
    OP_REQUIRES(context, lr_padding_ < sequence_length_,
                tensorflow::errors::FailedPrecondition(
                    "Sequence length must be longer than padding."));
  }

  void Compute(OpKernelContext *context) override {
    // Figure out how many sequences we need.
    const Tensor &data = context->input(0);
    const int embedding_size = data.shape().dim_size(1);
    const Tensor &offsets = context->input(1);
    const int offsets_size = offsets.shape().dim_size(0);
    const int batch_size = offsets_size - 1;
    const auto &offset_data = offsets.vec<int32>();

    // Count the output rows: each passage contributes
    // ceil(element_length / sequence_length_) subsequences; empty passages
    // contribute none.
    int num_elements = 0;
    for (int i = 0; i < batch_size; ++i) {
      int element_length = offset_data(i + 1) - offset_data(i);
      if (element_length > 0) {
        int num_full_sequences = element_length / sequence_length_;
        int length = ((element_length % sequence_length_) == 0)
                         ? (num_full_sequences)
                         : (num_full_sequences + 1);
        num_elements += length;
        VLOG(2) << "Item " << i << " of length " << element_length
                << " will use " << length << ". Total: " << num_elements;
      }
    }
    // Each output row carries lr_padding_ context slots on both sides.
    int output_sequence_length = 2 * lr_padding_ + sequence_length_;
    VLOG(2) << "Rebatch shape: " << num_elements << " "
            << output_sequence_length << " " << embedding_size;

    // Allocate the output tensors. The data output is zeroed so that any
    // slot not explicitly copied below acts as zero padding.
    Tensor *output;
    OP_REQUIRES_OK(
        context,
        context->allocate_output(
            0,
            TensorShape({num_elements, output_sequence_length, embedding_size}),
            &output));
    output->flat<float>().setZero();
    Tensor *indices;
    OP_REQUIRES_OK(context, context->allocate_output(
                                1, TensorShape({num_elements}), &indices));

    const float *dense_data = data.flat<float>().data();
    float *output_data = output->flat<float>().data();
    // Offset (in floats) of the first non-context slot within an output row.
    int64 start_offset = lr_padding_ * embedding_size;
    // A continuation row can hold up to lr_padding_ + sequence_length_
    // tokens; the first row of a passage is also capped at this many.
    int64 seq_max_length = lr_padding_ + sequence_length_;
    int64 row_index = 0;
    for (int i = 0; i < batch_size; ++i) {
      int64 element_length = offset_data(i + 1) - offset_data(i);
      VLOG(2) << "Rebatching index " << i << " with size " << element_length;
      if (element_length == 0) {
        continue;
      }
      // First subsequence of the passage: there is no left context, so the
      // copy starts start_offset floats into the row, leaving the left
      // padding slots zero.
      int64 first_seq_length = std::min(element_length, seq_max_length);
      int64 subseqence_length = first_seq_length * embedding_size;
      int64 dense_start = offset_data(i) * embedding_size;
      int64 output_start =
          row_index * output_sequence_length * embedding_size + start_offset;
      for (int j = 0; j < subseqence_length; ++j) {
        output_data[output_start + j] = dense_data[dense_start + j];
      }
      indices->vec<int32>()(row_index) = i;
      VLOG(2) << "Rebatched " << i << " to " << row_index;
      ++row_index;
      int64 tokens_remaining = element_length - sequence_length_;
      VLOG(2) << "Remaining: " << tokens_remaining;
      // Continuation subsequences: copy lr_padding_ tokens of left context
      // plus up to sequence_length_ new tokens, flush against the row start.
      // NOTE(review): the declarations below shadow the same-named variables
      // above; they hold this iteration's values.
      while (tokens_remaining > 0) {
        int64 seq_length = std::min(tokens_remaining, seq_max_length);
        int64 subseqence_length = (seq_length + lr_padding_) * embedding_size;
        int64 data_start =
            (offset_data(i + 1) - tokens_remaining) - lr_padding_;
        int64 dense_start = data_start * embedding_size;
        int64 output_start =
            row_index * output_sequence_length * embedding_size;
        for (int j = 0; j < subseqence_length; ++j) {
          output_data[output_start + j] = dense_data[dense_start + j];
        }
        indices->vec<int32>()(row_index) = i;
        VLOG(2) << "Rebatched " << i << " to " << row_index;
        ++row_index;
        tokens_remaining -= sequence_length_;
        VLOG(2) << "Remaining: " << tokens_remaining;
      }
    }
    for (int j = 0; j < num_elements; ++j) {
      VLOG(2) << "Rebatch item :" << j
              << " has index: " << indices->vec<int32>()(j);
    }
  }

 private:
  int sequence_length_;  // Tokens of new content per subsequence.
  int lr_padding_;       // Context tokens kept on each side of a subsequence.

  TF_DISALLOW_COPY_AND_ASSIGN(RebatchDensor);
};
REGISTER_KERNEL_BUILDER(Name("RebatchDensor").Device(DEVICE_CPU),
RebatchDensor);
// Scatters a batch of fixed-size subsequences (as produced by RebatchDensor)
// back into one zero-padded row per original sequence.
//
// Input 0 ("data") is a rank-4 tensor whose dims 0, 2, and 3 are read as
// [num_subsequences, -, sequence_length, embedding_size] (dim 1 is unused by
// this kernel); input 1 ("indices") maps each subsequence to its original
// sequence; input 2 ("offsets") is the [batch_size + 1] offsets vector of the
// original sequences. Output 0 is [batch_size, max_sequence_size,
// embedding_size], where max_sequence_size is the longest original sequence.
class UnbatchSubsequences : public OpKernel {
 public:
  explicit UnbatchSubsequences(OpKernelConstruction *context)
      : OpKernel(context) {
    OP_REQUIRES_OK(context, context->MatchSignature(
                                {DT_FLOAT, DT_INT32, DT_INT32}, {DT_FLOAT}));
  }

  void Compute(OpKernelContext *context) override {
    // Figure out how many sequences we need.
    const Tensor &data = context->input(0);
    const int input_batch_size = data.shape().dim_size(0);
    const int sequence_length = data.shape().dim_size(2);
    const int embedding_size = data.shape().dim_size(3);
    const int input_size = data.NumElements();
    const Tensor &indices = context->input(1);
    const int indices_size = indices.shape().dim_size(0);
    const Tensor &offsets = context->input(2);
    const int offsets_size = offsets.shape().dim_size(0);
    const int batch_size = offsets_size - 1;
    const auto &offset_data = offsets.vec<int32>();
    // The output must be wide enough for the longest original sequence.
    int max_sequence_size = 0;
    for (int i = 0; i < batch_size; ++i) {
      int element_length = offset_data(i + 1) - offset_data(i);
      if (element_length > max_sequence_size) {
        max_sequence_size = element_length;
      }
    }

    // Allocate the output tensors. Zeroed so short sequences stay padded.
    Tensor *output;
    VLOG(2) << "Unbatch shape: " << batch_size << " " << max_sequence_size
            << " " << embedding_size;
    OP_REQUIRES_OK(
        context,
        context->allocate_output(
            0, TensorShape({batch_size, max_sequence_size, embedding_size}),
            &output));
    output->flat<float>().setZero();
    int output_size = output->NumElements();

    const float *input_data = data.flat<float>().data();
    float *output_data = output->flat<float>().data();
    const int32 *index_data = indices.flat<int32>().data();
    int previous_index = -1;
    int current_sequence_element = 0;
    VLOG(2) << "Sequence length: " << sequence_length;
    VLOG(2) << "Indices size: " << indices_size;
    for (int i = 0; i < indices_size; ++i) {
      int current_index = index_data[i];
      CHECK(current_index < input_batch_size) << "Index out of bounds.";
      // Subsequences of one sequence are assumed to arrive consecutively
      // with nondecreasing indices (TODO confirm against RebatchDensor);
      // reset the write cursor when a new sequence starts.
      if (current_index > previous_index) {
        previous_index = current_index;
        current_sequence_element = 0;
      }
      // Clamp the copy so we never write past the padded output row.
      int current_sequence_length = std::min(
          sequence_length, max_sequence_size - current_sequence_element);
      int64 input_offset = i * sequence_length * embedding_size;
      int64 output_offset =
          (current_index * max_sequence_size + current_sequence_element) *
          embedding_size;
      VLOG(2) << "cur_ind: " << current_index
              << " cur_element: " << current_sequence_element
              << " cur sqlen: " << current_sequence_length
              << " in: " << input_offset << " out: " << output_offset;
      for (int j = 0; j < current_sequence_length * embedding_size; ++j) {
        CHECK((output_offset + j) < output_size) << "output index invalid";
        CHECK((input_offset + j) < input_size) << "input index invalid";
        output_data[output_offset + j] = input_data[input_offset + j];
      }
      current_sequence_element += current_sequence_length;
    }
  }

 private:
  TF_DISALLOW_COPY_AND_ASSIGN(UnbatchSubsequences);
};
REGISTER_KERNEL_BUILDER(Name("UnbatchSubsequences").Device(DEVICE_CPU),
UnbatchSubsequences);
/*******************************************************************************
* ComputeSessionOps below here.
******************************************************************************/
......@@ -450,8 +654,8 @@ class ExtractFixedFeatures : public ComputeSessionOp {
component_name(), indices_allocator, ids_allocator, weights_allocator,
channel_id_);
VLOG(2) << "Extracted features (" << num_features << "): "
<< " ids=" << context->mutable_output(1)->vec<int64>()
<< " weights=" << context->mutable_output(2)->vec<float>()
<< " ids=" << context->mutable_output(1)->vec<int64>()
<< " weights=" << context->mutable_output(2)->vec<float>()
<< " indices=" << context->mutable_output(0)->vec<int32>();
}
......@@ -546,7 +750,8 @@ REGISTER_KERNEL_BUILDER(Name("ExtractLinkFeatures").Device(DEVICE_CPU),
// Given a handle to a BatchedBeamComponentState, emits a vector of gold
// labels.
// The vector of gold labels has size batch_size * beam_size.
// The vector of gold labels has size batch_size * beam_size. The code assumes
// one label per instance.
class EmitOracleLabels : public ComputeSessionOp {
public:
explicit EmitOracleLabels(OpKernelConstruction *context)
......@@ -567,12 +772,13 @@ class EmitOracleLabels : public ComputeSessionOp {
TensorShape({session->BatchSize(component_name()) *
session->BeamSize(component_name())}),
&output));
std::vector<std::vector<int>> batched_labels =
std::vector<std::vector<std::vector<Label>>> batched_labels =
session->EmitOracleLabels(component_name());
int raw_index = 0;
for (const auto &batch_vector : batched_labels) {
for (const auto &label : batch_vector) {
output->vec<int32>()(raw_index) = label;
for (const auto &instance_labels : batch_vector) {
// The code assumes there is one label per instance.
output->vec<int32>()(raw_index) = instance_labels.at(0).id;
++raw_index;
}
}
......@@ -585,6 +791,66 @@ class EmitOracleLabels : public ComputeSessionOp {
REGISTER_KERNEL_BUILDER(Name("EmitOracleLabels").Device(DEVICE_CPU),
EmitOracleLabels);
// Given a handle to a BatchedBeamComponentState, emits corresponding vectors of
// indices, gold labels, and probabilities. The size of the output vectors is
// equal to the sum of the number of labels for each instance in the beams in
// the batch.
class EmitOracleLabelsAndProbabilities : public ComputeSessionOp {
public:
explicit EmitOracleLabelsAndProbabilities(OpKernelConstruction *context)
: ComputeSessionOp(context) {
OP_REQUIRES_OK(context, context->MatchSignature(
{DT_STRING}, {DT_INT32, DT_INT32, DT_FLOAT}));
}
bool OutputsHandle() const override { return false; }
bool RequiresComponentName() const override { return true; }
void ComputeWithState(OpKernelContext *context,
ComputeSession *session) override {
const std::vector<std::vector<std::vector<Label>>> batched_labels =
session->EmitOracleLabels(component_name());
int label_count = 0;
for (const auto &beam : batched_labels) {
for (const auto &instance : beam) {
label_count += instance.size();
}
}
Tensor *indices_output;
OP_REQUIRES_OK(context,
context->allocate_output(0, TensorShape({label_count}),
&indices_output));
Tensor *label_output;
OP_REQUIRES_OK(context, context->allocate_output(
1, TensorShape({label_count}), &label_output));
Tensor *prob_output;
OP_REQUIRES_OK(context, context->allocate_output(
2, TensorShape({label_count}), &prob_output));
// Index keeping track of each instance in the beams in the batch.
int instance_index = -1;
int raw_index = -1;
for (const auto &beam : batched_labels) {
for (const auto &instance : beam) {
++instance_index;
for (const Label &label : instance) {
++raw_index;
indices_output->vec<int32>()(raw_index) = instance_index;
label_output->vec<int32>()(raw_index) = label.id;
prob_output->vec<float>()(raw_index) = label.probability;
}
}
}
}
private:
TF_DISALLOW_COPY_AND_ASSIGN(EmitOracleLabelsAndProbabilities);
};
REGISTER_KERNEL_BUILDER(
Name("EmitOracleLabelsAndProbabilities").Device(DEVICE_CPU),
EmitOracleLabelsAndProbabilities);
// Given a handle to a ComponentState, emits a single bool indicating
// whether all elements in the batch contain beams containing all final states.
class EmitAllFinal : public ComputeSessionOp {
......
......@@ -23,6 +23,7 @@
#include "dragnn/core/resource_container.h"
#include "dragnn/core/test/generic.h"
#include "dragnn/core/test/mock_compute_session.h"
#include "dragnn/core/util/label.h"
#include <gmock/gmock.h>
......@@ -44,26 +45,26 @@ namespace syntaxnet {
namespace dragnn {
using tensorflow::AllocatorAttributes;
using tensorflow::checkpoint::TensorSliceReaderCacheWrapper;
using tensorflow::DT_BOOL;
using tensorflow::DT_FLOAT;
using tensorflow::DT_STRING;
using tensorflow::DT_INT32;
using tensorflow::FrameAndIter;
using tensorflow::DT_STRING;
using tensorflow::DataType;
using tensorflow::FrameAndIter;
using tensorflow::NodeDefBuilder;
using tensorflow::OpKernelContext;
using tensorflow::ResourceMgr;
using tensorflow::ScopedStepContainer;
using tensorflow::Status;
using tensorflow::test::SetOutputAttrs;
using tensorflow::TensorShape;
using tensorflow::checkpoint::TensorSliceReaderCacheWrapper;
using tensorflow::test::SetOutputAttrs;
using testing::_;
using testing::ElementsAreArray;
using testing::Invoke;
using testing::Pointwise;
using testing::Return;
using testing::_;
typedef ResourceContainer<ComputeSession> ComputeSessionResource;
typedef ResourceContainer<ComputeSessionPool> ComputeSessionPoolResource;
......@@ -126,12 +127,18 @@ class TestComponent : public Component {
int batch_size_padding, int num_steps_padding, int output_array_size,
const vector<const float *> &per_channel_embeddings,
float *embedding_matrix) override {}
void BulkEmbedDenseFixedFeatures(
const vector<const float *> &per_channel_embeddings,
float *embedding_output, int embedding_output_size,
int *offset_array_output, int offset_array_size) override {}
int BulkDenseFeatureSize() const override { return 0; }
std::vector<LinkFeatures> GetRawLinkFeatures(int channel_id) const override {
std::vector<LinkFeatures> ret;
return ret;
}
std::vector<std::vector<int>> GetOracleLabels() const override {
std::vector<std::vector<int>> ret;
std::vector<std::vector<std::vector<Label>>> GetOracleLabels()
const override {
std::vector<std::vector<std::vector<Label>>> ret;
return ret;
}
void FinalizeData() override {}
......@@ -482,6 +489,201 @@ TEST_F(DragnnOpKernelsTest, GetSessionCountsOpTest) {
GetOutput(0)->vec<int64>()(1));
}
// The RebatchDensor op should split ragged passage embeddings into padded
// fixed-size subsequences.
TEST_F(DragnnOpKernelsTest, RebatchDensorOpTest) {
  int sequence_length = 3;
  int pad_length = 2;
  TF_ASSERT_OK(NodeDefBuilder("rebatch_densor", "RebatchDensor")
                   .Attr("sequence_length", sequence_length)
                   .Attr("lr_padding", pad_length)
                   .Input(FakeInput(DT_FLOAT))  // The dense data tensor.
                   .Input(FakeInput(DT_INT32))  // The offsets tensor.
                   .Finalize(node_def()));
  TF_ASSERT_OK(InitOp());

  // Set the input data: two passages of 4 and 6 tokens, 2 floats per token.
  const std::vector<float> weights = {
      // PASSAGE 1
      1.01, 1.02,  //
      1.04, 1.05,  //
      1.07, 1.08,  //
      1.10, 1.11,  //
      // PASSAGE 2
      2.01, 2.02,  //
      2.03, 2.04,  //
      2.05, 2.06,  //
      2.07, 2.08,  //
      2.09, 2.10,  //
      2.11, 2.12   //
  };
  AddInputFromArray<float>(TensorShape({10, 2}), weights);
  const std::vector<int> offsets = {0, 4, 10};
  AddInputFromArray<int>(TensorShape({3}), offsets);

  // Reset the test context to ensure it's clean.
  ResetOpKernelContext();

  // Run the kernel.
  TF_EXPECT_OK(RunOpKernelWithContext());

  // With sequence_length=3 and lr_padding=2, each passage splits into two
  // subsequences of output width 2*2+3=7. Rows 0 and 2 start a passage, so
  // their two leading context slots are zeros; rows 1 and 3 are
  // continuations and carry two context embeddings from the preceding
  // tokens. Unused trailing slots remain zero.
  const std::vector<float> expected_weights = {
      // BATCH 0
      0.0, 0.0,    //
      0.0, 0.0,    //
      1.01, 1.02,  //
      1.04, 1.05,  //
      1.07, 1.08,  //
      1.10, 1.11,  //
      0.0, 0.0,    //
      // BATCH 1
      1.04, 1.05,  //
      1.07, 1.08,  //
      1.10, 1.11,  //
      0.0, 0.0,    //
      0.0, 0.0,    //
      0.0, 0.0,    //
      0.0, 0.0,    //
      // BATCH 2
      0.0, 0.0,    //
      0.0, 0.0,    //
      2.01, 2.02,  //
      2.03, 2.04,  //
      2.05, 2.06,  //
      2.07, 2.08,  //
      2.09, 2.10,  //
      // BATCH 3
      2.03, 2.04,  //
      2.05, 2.06,  //
      2.07, 2.08,  //
      2.09, 2.10,  //
      2.11, 2.12,  //
      0.0, 0.0,    //
      0.0, 0.0,    //
  };
  for (int i = 0; i < expected_weights.size(); ++i) {
    LOG(INFO) << GetOutput(0)->flat<float>()(i);
  }

  // The output should have dimensions {4, 7, 2}.
  EXPECT_EQ(4, GetOutput(0)->dim_size(0));
  EXPECT_EQ(7, GetOutput(0)->dim_size(1));
  EXPECT_EQ(2, GetOutput(0)->dim_size(2));

  // The output should match the expected tensor.
  for (int i = 0; i < expected_weights.size(); ++i) {
    EXPECT_EQ(expected_weights[i], GetOutput(0)->flat<float>()(i))
        << "Failed at index " << i;
  }

  // The indices output should have dimension {4}: one entry per subsequence,
  // mapping it back to its source passage.
  EXPECT_EQ(4, GetOutput(1)->dim_size(0));
  std::vector<int> expected_indices = {0, 0, 1, 1};
  for (int i = 0; i < expected_indices.size(); ++i) {
    EXPECT_EQ(expected_indices[i], GetOutput(1)->flat<int32>()(i))
        << "Failed at index " << i;
  }
}
// The UnbatchSubsequences op should reassemble fixed-size subsequences into
// one row per original sequence, zero-padded to the longest sequence.
TEST_F(DragnnOpKernelsTest, UnbatchSubsequences) {
  TF_ASSERT_OK(NodeDefBuilder("unbatch_subsequences", "UnbatchSubsequences")
                   .Input(FakeInput(DT_FLOAT))  // The data tensor.
                   .Input(FakeInput(DT_INT32))  // The index tensor.
                   .Input(FakeInput(DT_INT32))  // The offsets tensor.
                   .Finalize(node_def()));
  TF_ASSERT_OK(InitOp());

  // Set the input data: four subsequences of 5 tokens each; the first two
  // belong to sequence 0, the others to sequences 1 and 2.
  const std::vector<float> input = {
      // BATCH 0
      1.01, 1.02,  //
      1.04, 1.05,  //
      1.07, 1.08,  //
      1.10, 1.11,  //
      1.12, 1.13,  //
      // BATCH 1
      1.14, 1.15,  //
      0.0, 0.0,    //
      0.0, 0.0,    //
      0.0, 0.0,    //
      0.0, 0.0,    //
      // BATCH 2
      2.01, 2.02,  //
      0.0, 0.0,    //
      0.0, 0.0,    //
      0.0, 0.0,    //
      0.0, 0.0,    //
      // BATCH 3
      3.01, 3.02,  //
      0.0, 0.0,    //
      0.0, 0.0,    //
      0.0, 0.0,    //
      0.0, 0.0     //
  };
  AddInputFromArray<float>(TensorShape({4, 1, 5, 2}), input);
  const std::vector<int> indices = {0, 0, 1, 2};
  AddInputFromArray<int>(TensorShape({4}), indices);
  const std::vector<int> offsets = {0, 6, 7, 8};
  AddInputFromArray<int>(TensorShape({4}), offsets);

  // Reset the test context to ensure it's clean.
  ResetOpKernelContext();

  // Run the kernel.
  TF_EXPECT_OK(RunOpKernelWithContext());

  // Sequence 0 (length 6) is reassembled from subsequences 0 and 1;
  // sequences 1 and 2 (length 1 each) come from subsequences 2 and 3 and are
  // zero-padded out to the maximum sequence length of 6.
  const std::vector<float> expected_weights = {
      // BATCH 0
      1.01, 1.02,  //
      1.04, 1.05,  //
      1.07, 1.08,  //
      1.10, 1.11,  //
      1.12, 1.13,  //
      1.14, 1.15,  //
      // BATCH 1
      2.01, 2.02,  //
      0.0, 0.0,    //
      0.0, 0.0,    //
      0.0, 0.0,    //
      0.0, 0.0,    //
      0.0, 0.0,    //
      // BATCH 2
      3.01, 3.02,  //
      0.0, 0.0,    //
      0.0, 0.0,    //
      0.0, 0.0,    //
      0.0, 0.0,    //
      0.0, 0.0     //
  };
  for (int i = 0; i < expected_weights.size(); ++i) {
    LOG(INFO) << GetOutput(0)->flat<float>()(i);
  }

  // The output should have dimensions {3, 6, 2}.
  EXPECT_EQ(3, GetOutput(0)->dim_size(0));
  EXPECT_EQ(6, GetOutput(0)->dim_size(1));
  EXPECT_EQ(2, GetOutput(0)->dim_size(2));

  // The output should match the expected tensor.
  for (int i = 0; i < expected_weights.size(); ++i) {
    EXPECT_EQ(expected_weights[i], GetOutput(0)->flat<float>()(i))
        << "Failed at index " << i;
  }
}
// The AdvanceFromOracle op should call AdvanceFromOracle on the specified
// component name.
TEST_F(DragnnOpKernelsTest, AdvanceFromOracleOpTest) {
......@@ -651,7 +853,8 @@ TEST_F(DragnnOpKernelsTest, ExtractFixedFeaturesOpTest) {
// If we have 3 features, for a given channel, we might have:
// feature a: (5, 1)
// feature b: (5, 0.5), (6, 0.7)
// feature c: (3, 0.1), (7, [empty]) <- Empty weights are equivalent to 1.0.
// feature c: (3, 0.1), (7, [empty]) <- Empty weights are equivalent
// to 1.0.
// In this case:
// indices should look like [0 , 1 , 1 , 2 , 2 ]
// ids should be [5 , 5 , 6 , 3 , 7 ]
......@@ -727,15 +930,15 @@ TEST_F(DragnnOpKernelsTest, ExtractLinkFeaturesOpTest) {
MockComputeSession *mock_session_ptr = mock_session.get();
// This op will return link features in two flat arrays using batch-major
// ordering. So, if we have a batch of 2 and a beam of 3, with data as follows
// (note that the features are {batch,beam,step} and [] is 'empty')
// ordering. So, if we have a batch of 2 and a beam of 3, with data as
// follows (note that the features are {batch,beam,step} and [] is 'empty')
// batch 1 features: {{02,03,[]},{01,00,04},{08,06,01}}
// batch 2 features: {{12,13,14},{11,12,-1},{18,16,20}}
//
// and a **source component** beam size of 5 should result in output tensors:
// step_idx (tensor 0): {-1, 4, 1, 14, -1, 20}
// array_idx (tensor 1): { 0, 5, 46, 73, 0, 106}
// (0 [step=-1]),(5=1*5+0),(46=8*5+6),(73=12*5+13),(0 [step=-1]),(96=18*5+16)
// and a **source component** beam size of 5 should result in output
// tensors: step_idx (tensor 0): {-1, 4, 1, 14, -1, 20} array_idx (tensor
// 1): { 0, 5, 46, 73, 0, 106} (0
// [step=-1]),(5=1*5+0),(46=8*5+6),(73=12*5+13),(0 [step=-1]),(96=18*5+16)
constexpr int kSourceComponentBeamSize = 5;
std::vector<LinkFeatures> features;
......@@ -814,8 +1017,11 @@ TEST_F(DragnnOpKernelsTest, EmitOracleLabelsOpTest) {
constexpr int kBatchSize = 2;
constexpr int kBeamSize = 4;
const std::vector<std::vector<int>> oracle_labels(
{{1, 3, 5, 7}, {2, 4, 6, 8}});
// Vectors containing, respectively, label ids and the corresponding Labels.
const std::vector<std::vector<std::vector<Label>>> oracle_labels(
{{{{1, 1.f}}, {{3, 1.f}}, {{5, 1.f}}, {{7, 1.f}}},
{{{2, 1.f}}, {{4, 1.f}}, {{6, 1.f}}, {{8, 1.f}}}});
EXPECT_CALL(*mock_session_ptr, BatchSize(component_name))
.WillRepeatedly(Return(kBatchSize));
......@@ -836,6 +1042,73 @@ TEST_F(DragnnOpKernelsTest, EmitOracleLabelsOpTest) {
}
}
// The EmitOracleLabelsAndProbabilities op returns vectors of instance
// indices, labels, and probabilities corresponding to the elements in the
// beams in the batch.
TEST_F(DragnnOpKernelsTest, EmitOracleLabelsAndProbabilitiesOpTest) {
  // Create and initialize the kernel under test.
  const string component_name = "TESTING_COMPONENT_NAME";
  TF_ASSERT_OK(
      NodeDefBuilder("emit_oracle_labels_and_probabilities",
                     "EmitOracleLabelsAndProbabilities")
          .Attr("component", component_name)
          .Input(FakeInput(DT_STRING))  // The handle for the ComputeSession.
          .Finalize(node_def()));
  TF_ASSERT_OK(InitOp());

  // Set the input data: the (container, id) pair identifying the session.
  const string container_string = "container_str";
  const string id_string = "id_str";
  AddInputFromList<string>(TensorShape({2}), {container_string, id_string});

  // Reset the test context to ensure it's clean.
  ResetOpKernelContext();

  // Create a MockComputeSession and set expectations.
  std::unique_ptr<MockComputeSession> mock_session(new MockComputeSession());
  MockComputeSession *mock_session_ptr = mock_session.get();

  // Wrap the ComputeSessionResource and put it into the resource manager.
  TF_ASSERT_OK(resource_mgr()->Create<ComputeSessionResource>(
      container_string, id_string,
      new ComputeSessionResource(std::move(mock_session))));

  // The op should request the oracle labels, and probabilities. They should
  // be returned in batch major order, so if the label:probability pairs are:
  // batch 1 oracle labels: {{1:0.6, 2:0.8}, {3:1.0}, {5:0.7}}
  // batch 2 oracle labels: {{2:0.9}, {4:1.0}, {6:0.3, 8:0.6}}
  // then the resulting output tensors are:
  //   indices_output: {0, 0, 1, 2, 3, 4, 5, 5}
  //   label_output: {1, 2, 3, 5, 2, 4, 6, 8}
  //   prob_output: {0.6, 0.8, 1.0, 0.7, 0.9, 1.0, 0.3, 0.6}
  // (Each beam has 3 instances, so instance indices run 0..5 across the two
  // beams; instances with two labels repeat their index.)
  // Oracle labels along with their probabilities.
  const std::vector<std::vector<std::vector<Label>>> oracle_labels(
      {{{{1, 0.6}, {2, 0.8}}, {{3, 1.0}}, {{5, 0.7}}},
       {{{2, 0.9}}, {{4, 1.0}}, {{6, 0.3}, {8, 0.6}}}});
  EXPECT_CALL(*mock_session_ptr, EmitOracleLabels(component_name))
      .WillOnce(Return(oracle_labels));

  const std::vector<int> expected_indices({0, 0, 1, 2, 3, 4, 5, 5});
  const std::vector<int> expected_labels({1, 2, 3, 5, 2, 4, 6, 8});
  const std::vector<float> expected_probs(
      {0.6, 0.8, 1.0, 0.7, 0.9, 1.0, 0.3, 0.6});

  // Run the kernel.
  TF_EXPECT_OK(RunOpKernelWithContext());

  // Validate the outputs.
  EXPECT_EQ(expected_indices.size(), GetOutput(0)->NumElements());
  EXPECT_EQ(expected_labels.size(), GetOutput(1)->NumElements());
  EXPECT_EQ(expected_probs.size(), GetOutput(2)->NumElements());
  for (int i = 0; i < expected_indices.size(); ++i) {
    EXPECT_EQ(expected_indices[i], GetOutput(0)->vec<int32>()(i));
    EXPECT_EQ(expected_labels[i], GetOutput(1)->vec<int32>()(i));
    EXPECT_EQ(expected_probs[i], GetOutput(2)->vec<float>()(i));
  }
}
// The EmitAllFinal op should return the result of IsTerminal(component_name).
TEST_F(DragnnOpKernelsTest, EmitAllFinalOpTest) {
// Create and initialize the kernel under test.
......
......@@ -13,7 +13,9 @@
// limitations under the License.
// =============================================================================
#include "dragnn/core/ops/shape_helpers.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/shape_inference.h"
namespace syntaxnet {
namespace dragnn {
......@@ -22,6 +24,10 @@ REGISTER_OP("SetAssetDirectory")
.Input("asset_directory: string")
.Output("asset_directory_out: string")
.SetIsStateful()
.SetShapeFn([](tensorflow::shape_inference::InferenceContext *context) {
context->set_output(0, context->Vector(1));
return ScalarInputShape(0, context);
})
.Doc(R"doc(
Override the paths to assets specified in the MasterSpec with the given
asset_directory. This op must be called before any calls to GetSession, as it
......@@ -38,6 +44,10 @@ REGISTER_OP("GetSession")
.Attr("grid_point: string")
.Output("handle: string")
.SetIsStateful()
.SetShapeFn([](tensorflow::shape_inference::InferenceContext *context) {
TF_RETURN_IF_ERROR(ScalarInputShape(0, context));
return ComputeSessionHandleOutputShape(context);
})
.Doc(R"doc(
Given MasterSpec and GridPoint protos, outputs a handle to a ComputeSession.
......@@ -48,7 +58,11 @@ grid_point: A serialized syntaxnet.dragnn.GridPoint proto.
handle: A string handle to a ComputeSession.
)doc");
REGISTER_OP("ReleaseSession").Input("handle: string").SetIsStateful().Doc(R"doc(
REGISTER_OP("ReleaseSession")
.Input("handle: string")
.SetIsStateful()
.SetShapeFn(ComputeSessionHandleInputShape)
.Doc(R"doc(
Given a ComputeSession, return it to the ComputeSession pool.
This ComputeSession will no longer be available after this op returns.
......@@ -60,6 +74,10 @@ REGISTER_OP("GetSessionCounts")
.Input("container: string")
.Output("stats: int64")
.SetIsStateful()
.SetShapeFn([](tensorflow::shape_inference::InferenceContext *context) {
context->set_output(0, context->Vector(2));
return ScalarInputShape(0, context);
})
.Doc(R"doc(
Given a container string, output session counts for that ComputeSessionPool.
......@@ -68,11 +86,70 @@ stats: A vector of stats. [0] is the total number of created sessions. [1] is
the number of sessions that are currently not in the pool.
)doc");
// Splits ragged passage data into padded fixed-size subsequences.
// rebatched_data is [num_subsequences, 2*lr_padding+sequence_length,
// embedding_dim]; rebatched_indices maps each subsequence back to its input
// passage (these outputs are not described in the doc string below).
// NOTE(review): SetIsStateful() disables constant folding for this op —
// confirm that is intentional for a pure data-reshaping op.
REGISTER_OP("RebatchDensor")
    .Input("dense_data: float")
    .Input("offsets: int32")
    .Attr("sequence_length: int")
    .Attr("lr_padding: int")
    .Output("rebatched_data: float")
    .Output("rebatched_indices: int32")
    .SetIsStateful()
    .SetShapeFn([](tensorflow::shape_inference::InferenceContext *context) {
      int sequence_length;
      TF_RETURN_IF_ERROR(context->GetAttr("sequence_length", &sequence_length));
      int lr_padding;
      TF_RETURN_IF_ERROR(context->GetAttr("lr_padding", &lr_padding));
      const int output_sequence_length = 2 * lr_padding + sequence_length;
      TF_RETURN_IF_ERROR(MatrixInputShape(0, context));
      const auto embedding_dim = context->Dim(context->input(0), 1);
      // Output 0: [unknown, 2*lr_padding+sequence_length, embedding_dim];
      // the number of subsequences is only known at run time.
      context->set_output(
          0, context->MakeShape({context->UnknownDim(), output_sequence_length,
                                 embedding_dim}));
      VectorOutputShape(1, context);
      return VectorInputShape(1, context);
    })
    .Doc(R"doc(
Rebatch a dense ragged tensor into a set of fixed-size subsequences.

dense_data: A tensor containing the dense ragged data.
offsets: The passage offsets into the dense_data tensor.
sequence_length: The size of the sequence length to rebatch to.
lr_padding: The amount of context to pad when breaking a passage.
)doc");
// Inverse of RebatchDensor: scatters fixed-size subsequence rows back into
// one zero-padded row per original sequence.
REGISTER_OP("UnbatchSubsequences")
    .Input("data: float")
    .Input("indices: int32")
    .Input("offsets: int32")
    .Output("rebatched_data: float")
    .SetIsStateful()
    .SetShapeFn([](tensorflow::shape_inference::InferenceContext *context) {
      // Input 0 must be rank 4; its last dimension (the embedding size)
      // carries through to the [batch, max_length, embedding] output. Batch
      // size and max sequence length are only known at run time.
      TF_RETURN_IF_ERROR(TensorInputShape(0, 4, context));
      const auto embedding_dim = context->Dim(context->input(0), 3);
      context->set_output(
          0, context->MakeShape({context->UnknownDim(), context->UnknownDim(),
                                 embedding_dim}));
      TF_RETURN_IF_ERROR(VectorInputShape(1, context));
      return VectorInputShape(2, context);
    })
    .Doc(R"doc(
Unbatch fixed-size subsequences back into their original sequences.

data: A tensor containing the fixed-length subsequences to unbatch.
indices: A tensor mapping the subsequences to the original sequences.
offsets: The passage offsets used to create the subsequences.
)doc");
REGISTER_OP("InitComponentData")
.Input("handle: string")
.Input("beam_size: int32")
.Attr("component: string")
.Output("output_handle: string")
.SetShapeFn([](tensorflow::shape_inference::InferenceContext *context) {
TF_RETURN_IF_ERROR(ScalarInputShape(1, context));
return ComputeSessionHandleInputAndOutputShape(context);
})
.Doc(R"doc(
Initialize a component with the given beam size for a given ComputeSession.
......@@ -86,6 +163,10 @@ REGISTER_OP("BatchSize")
.Input("handle: string")
.Attr("component: string")
.Output("batch_size: int32")
.SetShapeFn([](tensorflow::shape_inference::InferenceContext *context) {
ScalarOutputShape(0, context);
return ComputeSessionHandleInputShape(context);
})
.Doc(R"doc(
Given a ComputeSession and a component name,return the component batch size.
......@@ -99,6 +180,10 @@ REGISTER_OP("SetTracing")
.Input("tracing_on: bool")
.Attr("component: string = 'NOT_USED_FOR_THIS_OP'")
.Output("output_handle: string")
.SetShapeFn([](tensorflow::shape_inference::InferenceContext *context) {
TF_RETURN_IF_ERROR(ScalarInputShape(1, context));
return ComputeSessionHandleInputAndOutputShape(context);
})
.Doc(R"doc(
Given a ComputeSession, turns on or off tracing for all components.
......@@ -112,6 +197,10 @@ REGISTER_OP("AttachDataReader")
.Input("input_spec: string")
.Attr("component: string = 'NOT_USED_FOR_THIS_OP'")
.Output("output_handle: string")
.SetShapeFn([](tensorflow::shape_inference::InferenceContext *context) {
TF_RETURN_IF_ERROR(VectorInputShape(1, context));
return ComputeSessionHandleInputAndOutputShape(context);
})
.Doc(R"doc(
Given a ComputeSession, attach a data source.
......@@ -127,6 +216,7 @@ REGISTER_OP("AdvanceFromOracle")
.Input("handle: string")
.Attr("component: string")
.Output("output_handle: string")
.SetShapeFn(ComputeSessionHandleInputAndOutputShape)
.Doc(R"doc(
Given a ComputeSession and a Component name, advance the component via oracle.
......@@ -140,6 +230,10 @@ REGISTER_OP("AdvanceFromPrediction")
.Input("scores: float")
.Attr("component: string")
.Output("output_handle: string")
.SetShapeFn([](tensorflow::shape_inference::InferenceContext *context) {
TF_RETURN_IF_ERROR(MatrixInputShape(1, context));
return ComputeSessionHandleInputAndOutputShape(context);
})
.Doc(R"doc(
Given a ComputeSession, a Component name, and a score tensor, advance the state.
......@@ -156,6 +250,12 @@ REGISTER_OP("ExtractFixedFeatures")
.Output("weights: float")
.Attr("component: string")
.Attr("channel_id: int")
.SetShapeFn([](tensorflow::shape_inference::InferenceContext *context) {
VectorOutputShape(0, context);
VectorOutputShape(1, context);
VectorOutputShape(2, context);
return ComputeSessionHandleInputShape(context);
})
.Doc(R"doc(
Given a ComputeSession, Component, and channel index, output fixed features.
......@@ -179,6 +279,11 @@ REGISTER_OP("ExtractLinkFeatures")
.Output("idx: int32")
.Attr("component: string")
.Attr("channel_id: int")
.SetShapeFn([](tensorflow::shape_inference::InferenceContext *context) {
VectorOutputShape(0, context);
VectorOutputShape(1, context);
return ComputeSessionHandleInputShape(context);
})
.Doc(R"doc(
Given a ComputeSession, Component, and a channel index, outputs link features.
......@@ -195,6 +300,10 @@ REGISTER_OP("EmitOracleLabels")
.Input("handle: string")
.Output("gold_labels: int32")
.Attr("component: string")
.SetShapeFn([](tensorflow::shape_inference::InferenceContext *context) {
VectorOutputShape(0, context);
return ComputeSessionHandleInputShape(context);
})
.Doc(R"doc(
Given a ComputeSession and Component, emit a vector of gold labels.
......@@ -204,10 +313,39 @@ gold_labels: A [batch_size * beam_size] vector of gold labels for the current
component: The name of a Component instance, matching the ComponentSpec.name.
)doc");
// The three outputs are parallel vectors of equal length N (one entry per
// oracle label across all instances in all beams).
REGISTER_OP("EmitOracleLabelsAndProbabilities")
    .Input("handle: string")
    .Output("instance_indices: int32")
    .Output("gold_labels: int32")
    .Output("probabilities: float")
    .Attr("component: string")
    .SetShapeFn([](tensorflow::shape_inference::InferenceContext *context) {
      VectorOutputShape(0, context);
      VectorOutputShape(1, context);
      VectorOutputShape(2, context);
      return ComputeSessionHandleInputShape(context);
    })
    .Doc(R"doc(
Given a ComputeSession and Component, emit corresponding vectors of instance
indices, gold labels, and probabilities.

handle: A handle to a ComputeSession.
instance_indices: A vector [N] of indices for the current ComputeSession, where
                  N is the number of instance labels. Each element in each beam is
                  assigned an index.
gold_labels: A vector [N] of gold labels for the current ComputeSession.
probabilities: A vector [N] of probabilities for the current ComputeSession.
component: The name of a Component instance, matching the ComponentSpec.name.
)doc");
REGISTER_OP("EmitAllFinal")
.Input("handle: string")
.Output("all_final: bool")
.Attr("component: string")
.SetShapeFn([](tensorflow::shape_inference::InferenceContext *context) {
context->set_output(0, context->Vector(1));
return ComputeSessionHandleInputShape(context);
})
.Doc(R"doc(
Given a ComputeSession and Component, returns whether the Component is final.
......@@ -223,6 +361,7 @@ REGISTER_OP("WriteAnnotations")
.Input("handle: string")
.Output("output_handle: string")
.Attr("component: string")
.SetShapeFn(ComputeSessionHandleInputAndOutputShape)
.Doc(R"doc(
Given a ComputeSession, has the given component write out its annotations.
......@@ -238,6 +377,10 @@ REGISTER_OP("EmitAnnotations")
.Input("handle: string")
.Output("annotations: string")
.Attr("component: string")
.SetShapeFn([](tensorflow::shape_inference::InferenceContext *context) {
VectorOutputShape(0, context);
return ComputeSessionHandleInputShape(context);
})
.Doc(R"doc(
Given a ComputeSession, emits strings with final predictions for the model.
......@@ -252,6 +395,10 @@ REGISTER_OP("GetComponentTrace")
.Input("handle: string")
.Output("trace: string")
.Attr("component: string")
.SetShapeFn([](tensorflow::shape_inference::InferenceContext *context) {
VectorOutputShape(0, context);
return ComputeSessionHandleInputShape(context);
})
.Doc(R"doc(
Gets the raw MasterTrace proto for each batch, state, and beam slot.
......
// Copyright 2018 Google Inc. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
// Shape inference functions for DRAGNN ops.
#ifndef DRAGNN_CORE_OPS_SHAPE_HELPERS_H_
#define DRAGNN_CORE_OPS_SHAPE_HELPERS_H_
#include "syntaxnet/ops/shape_helpers.h"
#include "tensorflow/core/framework/shape_inference.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status.h"
namespace syntaxnet {
namespace dragnn {
// Returns OK iff the 0'th input of the |context| is compatible with the shape
// of a ComputeSession handle (a vector of length 2).
inline tensorflow::Status ComputeSessionHandleInputShape(
    tensorflow::shape_inference::InferenceContext *context) {
  const auto expected_shape = context->Vector(2);
  tensorflow::shape_inference::ShapeHandle merged;
  return context->Merge(context->input(0), expected_shape, &merged);
}
// Sets the 0'th output of the |context| to the shape of a ComputeSession
// handle (a vector of length 2). Cannot fail; always returns OK.
inline tensorflow::Status ComputeSessionHandleOutputShape(
    tensorflow::shape_inference::InferenceContext *context) {
  const auto handle_shape = context->Vector(2);
  context->set_output(0, handle_shape);
  return tensorflow::Status::OK();
}
// Convenience function equivalent to calling ComputeSessionHandleInputShape()
// and, if that succeeds, ComputeSessionHandleOutputShape(). Returns the first
// error encountered, or OK.
inline tensorflow::Status ComputeSessionHandleInputAndOutputShape(
    tensorflow::shape_inference::InferenceContext *context) {
  const tensorflow::Status input_status =
      ComputeSessionHandleInputShape(context);
  if (!input_status.ok()) return input_status;
  return ComputeSessionHandleOutputShape(context);
}
} // namespace dragnn
} // namespace syntaxnet
#endif // DRAGNN_CORE_OPS_SHAPE_HELPERS_H_
......@@ -12,8 +12,9 @@ cc_library(
"//dragnn/core:index_translator",
"//dragnn/core/interfaces:component",
"//dragnn/core/interfaces:transition_state",
"//dragnn/protos:data_proto",
"//dragnn/protos:spec_proto",
"//dragnn/core/util:label",
"//dragnn/protos:data_proto_cc",
"//dragnn/protos:spec_proto_cc",
"//syntaxnet:base",
"//syntaxnet:test_main",
],
......@@ -27,8 +28,9 @@ cc_library(
"//dragnn/components/util:bulk_feature_extractor",
"//dragnn/core:compute_session",
"//dragnn/core:input_batch_cache",
"//dragnn/protos:data_proto",
"//dragnn/protos:spec_proto",
"//dragnn/core/util:label",
"//dragnn/protos:data_proto_cc",
"//dragnn/protos:spec_proto_cc",
"//syntaxnet:base",
"//syntaxnet:test_main",
],
......@@ -45,6 +47,12 @@ cc_library(
],
)
# Header-only library providing FakeComponentBase (fake_component_base.h), a
# minimal stub implementation of the Component interface; depends only on the
# Component interface target. Presumably test-only — TODO confirm whether
# `testonly = True` should be set, as on sibling test rules.
cc_library(
name = "fake_component_base",
hdrs = ["fake_component_base.h"],
deps = ["//dragnn/core/interfaces:component"],
)
cc_library(
name = "generic",
testonly = True,
......
// Copyright 2017 Google Inc. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
#ifndef DRAGNN_CORE_TEST_FAKE_COMPONENT_BASE_H_
#define DRAGNN_CORE_TEST_FAKE_COMPONENT_BASE_H_
#include "dragnn/core/interfaces/component.h"
#include "dragnn/protos/data.pb.h"
namespace syntaxnet {
namespace dragnn {
// A minimal stub implementation of the Component interface for tests. Every
// method is a no-op that returns a neutral value (0, true, nullptr, or an
// empty container); only the component name passed to InitializeComponent()
// is retained. Tests subclass this and override just what they exercise.
class FakeComponentBase : public Component {
 public:
  FakeComponentBase() = default;

  // Remembers the name from |spec|; all other spec fields are ignored.
  void InitializeComponent(const ComponentSpec &spec) override {
    name_ = spec.name();
  }

  // Data setup and tracing are no-ops for the fake.
  void InitializeData(
      const std::vector<std::vector<const TransitionState *>> &states,
      int max_beam_size, InputBatchCache *input_data) override {}
  void InitializeTracing() override {}
  void DisableTracing() override {}

  // The fake is always ready and always terminal.
  bool IsReady() const override { return true; }
  bool IsTerminal() const override { return true; }

  string Name() const override { return name_; }

  // Trivial geometry: one batch, one beam item, zero steps, all indices 0.
  int BeamSize() const override { return 1; }
  int BatchSize() const override { return 1; }
  int StepsTaken(int batch_index) const override { return 0; }
  int GetBeamIndexAtStep(int step, int current_index,
                         int batch) const override {
    return 0;
  }
  int GetSourceBeamIndex(int current_index, int batch) const override {
    return 0;
  }

  // Transitions trivially succeed without mutating any state.
  bool AdvanceFromPrediction(const float *score_matrix, int num_items,
                             int num_actions) override {
    return true;
  }
  void AdvanceFromOracle() override {}

  std::function<int(int, int, int)> GetStepLookupFunction(
      const string &method) override {
    return nullptr;
  }

  std::vector<std::vector<const TransitionState *>> GetBeam() override {
    return {};
  }

  // Feature-extraction stubs: report zero features and write nothing.
  int GetFixedFeatures(std::function<int32 *(int)> allocate_indices,
                       std::function<int64 *(int)> allocate_ids,
                       std::function<float *(int)> allocate_weights,
                       int channel_id) const override {
    return 0;
  }
  void BulkEmbedFixedFeatures(
      int batch_size_padding, int num_steps_padding, int embedding_size,
      const vector<const float *> &per_channel_embeddings,
      float *embedding_output) override {}
  void BulkEmbedDenseFixedFeatures(
      const vector<const float *> &per_channel_embeddings,
      float *embedding_output, int embedding_output_size,
      int *offset_array_output, int offset_array_size) override {}
  int BulkDenseFeatureSize() const override { return 0; }
  int BulkGetFixedFeatures(const BulkFeatureExtractor &extractor) override {
    return 0;
  }
  std::vector<LinkFeatures> GetRawLinkFeatures(int channel_id) const override {
    return {};
  }
  std::vector<std::vector<std::vector<Label>>> GetOracleLabels()
      const override {
    return {};
  }

  void FinalizeData() override {}
  void ResetComponent() override {}

  // Tracing stubs: no traces are recorded or emitted.
  std::vector<std::vector<ComponentTrace>> GetTraceProtos() const override {
    return {};
  }
  void AddTranslatedLinkFeaturesToTrace(
      const std::vector<LinkFeatures> &features, int channel_id) override {}

  // Name recorded by InitializeComponent(); public so tests can inspect it.
  string name_;
};
} // namespace dragnn
} // namespace syntaxnet
#endif // DRAGNN_CORE_TEST_FAKE_COMPONENT_BASE_H_
......@@ -27,7 +27,8 @@
namespace syntaxnet {
namespace test {
MATCHER_P(EqualsProto, a, "Protos are not equivalent:") {
MATCHER_P(EqualsProto, a,
"Protos " + string(negation ? "aren't" : "are") + " equivalent:") {
return a.DebugString() == arg.DebugString();
}
......@@ -39,6 +40,16 @@ MATCHER_P(IsErrorWithSubstr, substr,
return !arg.ok() && arg.error_message().find(substr) != string::npos;
}
// Matches an error status whose code and message match |code| and |substr|.
MATCHER_P2(IsErrorWithCodeAndSubstr, code, substr,
           string(negation ? "isn't" : "is") +
               " an error Status whose code is " +
               ::testing::PrintToString(code) +
               " and whose message matches the substring '" +
               ::testing::PrintToString(substr) + "'") {
  // Guard clauses: reject OK statuses and mismatched codes before the
  // (comparatively expensive) substring search on the message.
  if (arg.ok()) return false;
  if (arg.code() != code) return false;
  return arg.error_message().find(substr) != string::npos;
}
// Returns the prefix for where the test data is stored.
string GetTestDataPrefix();
......
......@@ -22,6 +22,7 @@
#include "dragnn/core/index_translator.h"
#include "dragnn/core/interfaces/component.h"
#include "dragnn/core/interfaces/transition_state.h"
#include "dragnn/core/util/label.h"
#include "dragnn/protos/data.pb.h"
#include "dragnn/protos/spec.pb.h"
#include "syntaxnet/base.h"
......@@ -64,9 +65,15 @@ class MockComponent : public Component {
int output_array_size,
const vector<const float *> &per_channel_embeddings,
float *embedding_output));
MOCK_METHOD5(BulkEmbedDenseFixedFeatures,
void(const vector<const float *> &per_channel_embeddings,
float *embedding_output, int embedding_output_size,
int32 *offset_array_output, int offset_array_size));
MOCK_CONST_METHOD0(BulkDenseFeatureSize, int());
MOCK_CONST_METHOD1(GetRawLinkFeatures,
std::vector<LinkFeatures>(int channel_id));
MOCK_CONST_METHOD0(GetOracleLabels, std::vector<std::vector<int>>());
MOCK_CONST_METHOD0(GetOracleLabels,
std::vector<std::vector<std::vector<Label>>>());
MOCK_METHOD0(ResetComponent, void());
MOCK_METHOD1(GetStepLookupFunction,
std::function<int(int, int, int)>(const string &method));
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment