Commit 4364390a authored by Ivan Bogatyy's avatar Ivan Bogatyy Committed by calberti
Browse files

Release DRAGNN bulk networks (#2785)

* Release DRAGNN bulk networks
parent 638fd759
...@@ -80,6 +80,38 @@ pad_to_batch: If set, the op will pad/truncate to this number of elements. ...@@ -80,6 +80,38 @@ pad_to_batch: If set, the op will pad/truncate to this number of elements.
pad_to_steps: If set, the op will pad/truncate to this number of steps. pad_to_steps: If set, the op will pad/truncate to this number of steps.
)doc"); )doc");
// Registers the BulkEmbedFixedFeatures op: a bulk, inference-oriented variant
// of BulkFixedFeatures that performs the embedding lookup inside the op and
// emits concatenated embedding vectors directly. Marked stateful because it
// advances the underlying ComputeSession.
REGISTER_OP("BulkEmbedFixedFeatures")
    .Input("handle: string")
    .Input("embedding_matrix: num_channels * float")
    .Output("output_handle: string")
    .Output("embedding_vectors: float")
    .Output("num_steps: int32")
    .Attr("component: string")
    .Attr("num_channels: int")
    .Attr("pad_to_batch: int")
    .Attr("pad_to_steps: int")
    .SetIsStateful()
    .Doc(R"doc(
This op is a more efficient version of BulkFixedFeatures.

It is intended to be run with large batch sizes at inference time. The op takes
a handle to ComputeSession and embedding matrices as tensor inputs, and directly
outputs concatenated embedding vectors. It calls the BulkEmbedFixedFeatures
method on the underlying component directly, so it requires a padding vector
to be passed.

handle: A handle to ComputeSession.
embedding_matrix: Embedding matrices.
output_handle: A handle to the same ComputeSession after advancement.
embedding_vectors: (matrix of float) Concatenated embeddings,
  shaped as (batch * beam * token) x sum_channel(embedding_dim[channel]).
num_steps: The batch was unrolled for these many steps.
component: The name of a Component instance, matching the ComponentSpec.name.
num_channels: The number of FixedFeature channels.
pad_to_batch: The op will pad/truncate to this number of elements.
pad_to_steps: The op will pad/truncate to this number of steps.
)doc");
REGISTER_OP("BulkAdvanceFromOracle") REGISTER_OP("BulkAdvanceFromOracle")
.Input("handle: string") .Input("handle: string")
.Output("output_handle: string") .Output("output_handle: string")
......
...@@ -30,6 +30,7 @@ ...@@ -30,6 +30,7 @@
#include "tensorflow/core/framework/tensor_shape.h" #include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/core/threadpool.h" #include "tensorflow/core/lib/core/threadpool.h"
#include "tensorflow/core/lib/io/path.h"
#include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/mutex.h" #include "tensorflow/core/platform/mutex.h"
...@@ -40,6 +41,8 @@ using tensorflow::DT_INT32; ...@@ -40,6 +41,8 @@ using tensorflow::DT_INT32;
using tensorflow::DT_INT64; using tensorflow::DT_INT64;
using tensorflow::DT_STRING; using tensorflow::DT_STRING;
using tensorflow::DataType; using tensorflow::DataType;
using tensorflow::io::Dirname;
using tensorflow::io::JoinPath;
using tensorflow::OpKernel; using tensorflow::OpKernel;
using tensorflow::OpKernelConstruction; using tensorflow::OpKernelConstruction;
using tensorflow::OpKernelContext; using tensorflow::OpKernelContext;
...@@ -53,6 +56,59 @@ namespace dragnn { ...@@ -53,6 +56,59 @@ namespace dragnn {
typedef ResourceContainer<ComputeSession> ComputeSessionResource; typedef ResourceContainer<ComputeSession> ComputeSessionResource;
typedef ResourceContainer<ComputeSessionPool> ComputeSessionPoolResource; typedef ResourceContainer<ComputeSessionPool> ComputeSessionPoolResource;
typedef ResourceContainer<string> StringResource;
namespace {
const char kGlobalContainer[] = "__reserved_global_container";
const char kBasePathTag[] = "__reserved_asset_base_path";
const char kUnmanagedAssetDirectory[] = "assets.extra";
// When restoring a graph from a SavedModel, this op will rewrite the MasterSpec
// to point the DRAGNN components to the new resource locations. It will then
// add a string resource to the resource manager, which will be used to
// rebuild the masterspec before it is acquired in the GetComputeSession op.
class SetAssetDirectory : public OpKernel {
 public:
  // Validates the signature: one string input (the asset path) and one string
  // output (the same path, passed through unchanged).
  explicit SetAssetDirectory(OpKernelConstruction *context)
      : OpKernel(context) {
    OP_REQUIRES_OK(context, context->MatchSignature({DT_STRING}, {DT_STRING}));
  }

  // Derives the "assets.extra" directory from the incoming asset path and
  // stores it as a StringResource under a reserved global container/tag so
  // that GetSession can later rewrite the MasterSpec's resource paths.
  void Compute(OpKernelContext *context) override {
    ResourceMgr *rmgr = context->resource_manager();
    const string asset_path = context->input(0).scalar<string>()();

    // TODO(googleuser): Get this data in a way that isn't fragile as all hell.
    // "I've done stuff I ain't proud of... and the stuff I am proud of is
    // disgusting." -- Moe
    // NOTE(review): this assumes the input looks like <base>/assets/<file>, so
    // stripping two path components yields the SavedModel base directory —
    // confirm against the caller that passes this input.
    auto extra_asset_dir =
        JoinPath(Dirname(Dirname(asset_path)), kUnmanagedAssetDirectory);
    LOG(INFO) << "Found extra assets path at:" << extra_asset_dir;

    // Rather than attempt to rewrite the MasterSpec here, we save off a
    // StringResource containing the new asset path. It will be used in
    // the GetSession op, if it exists.
    std::unique_ptr<string> asset_path_ptr(new string(extra_asset_dir));
    OP_REQUIRES_OK(context, rmgr->Create<StringResource>(
                                kGlobalContainer, kBasePathTag,
                                new StringResource(std::move(asset_path_ptr))));

    // This isn't used anywhere - it just allows us to have an output so that
    // it's easier to reason about Tensorflow's graph execution.
    Tensor *output;
    OP_REQUIRES_OK(context,
                   context->allocate_output(0, TensorShape({1}), &output));
    output->vec<string>()(0) = asset_path;
  }

 private:
  TF_DISALLOW_COPY_AND_ASSIGN(SetAssetDirectory);
};

REGISTER_KERNEL_BUILDER(Name("SetAssetDirectory").Device(DEVICE_CPU),
                        SetAssetDirectory);
// Given a MasterSpec proto, outputs a handle to a ComputeSession. // Given a MasterSpec proto, outputs a handle to a ComputeSession.
class GetSession : public OpKernel { class GetSession : public OpKernel {
...@@ -66,6 +122,7 @@ class GetSession : public OpKernel { ...@@ -66,6 +122,7 @@ class GetSession : public OpKernel {
CHECK(master_spec_.ParseFromString(master_spec_str)); CHECK(master_spec_.ParseFromString(master_spec_str));
CHECK(grid_point_.ParseFromString(grid_point_spec_str)); CHECK(grid_point_.ParseFromString(grid_point_spec_str));
OP_REQUIRES_OK(context, context->MatchSignature({DT_STRING}, {DT_STRING})); OP_REQUIRES_OK(context, context->MatchSignature({DT_STRING}, {DT_STRING}));
has_overwritten_spec_ = false;
} }
void Compute(OpKernelContext *context) override { void Compute(OpKernelContext *context) override {
...@@ -74,10 +131,32 @@ class GetSession : public OpKernel { ...@@ -74,10 +131,32 @@ class GetSession : public OpKernel {
// Create the pool for this container, or re-use one that was allocated in a // Create the pool for this container, or re-use one that was allocated in a
// previous call. // previous call.
auto create_pool = [this, auto create_pool = [this, &rmgr,
&container](ComputeSessionPoolResource **resource) { &container](ComputeSessionPoolResource **resource) {
LOG(INFO) << "Creating new ComputeSessionPool in container handle: " if (has_overwritten_spec_) {
<< container; // TODO(googleuser): Figure out a way to test this.
// If there's already an overwritten spec, use that.
LOG(INFO) << "Creating new ComputeSessionPool in container handle: "
<< container << " with previously overwritten master spec.";
} else {
// If not, try to find the resource base.
StringResource *resource_base;
auto resource_base_lookup = rmgr->Lookup<StringResource>(
kGlobalContainer, kBasePathTag, &resource_base);
if (resource_base_lookup.ok()) {
// If that exists, the spec must be rewritten.
string resource_base_path = *resource_base->get();
LOG(INFO) << "Creating new ComputeSessionPool in container handle: "
<< container << " using resource directory base "
<< resource_base_path;
RewriteMasterSpec(resource_base_path);
resource_base->Unref();
} else {
// If not, just use the spec as is.
LOG(INFO) << "Creating new ComputeSessionPool in container handle: "
<< container << " without editing master spec.";
}
}
std::unique_ptr<ComputeSessionPool> pool( std::unique_ptr<ComputeSessionPool> pool(
new ComputeSessionPool(master_spec_, grid_point_)); new ComputeSessionPool(master_spec_, grid_point_));
*resource = new ComputeSessionPoolResource(std::move(pool)); *resource = new ComputeSessionPoolResource(std::move(pool));
...@@ -120,6 +199,23 @@ class GetSession : public OpKernel { ...@@ -120,6 +199,23 @@ class GetSession : public OpKernel {
} }
private: private:
// Rewrites this op's saved MasterSpec, appending the new base directory.
void RewriteMasterSpec(const string &new_base) {
for (auto &component_spec : *master_spec_.mutable_component()) {
for (auto &resource_def : *component_spec.mutable_resource()) {
for (auto &part_def : *resource_def.mutable_part()) {
part_def.set_file_pattern(
JoinPath(new_base, part_def.file_pattern()));
VLOG(2) << "New path: " << part_def.file_pattern();
}
}
}
VLOG(3) << "Rewritten spec: " << master_spec_.DebugString();
has_overwritten_spec_ = true;
}
bool has_overwritten_spec_;
MasterSpec master_spec_; MasterSpec master_spec_;
GridPoint grid_point_; GridPoint grid_point_;
...@@ -141,7 +237,6 @@ REGISTER_KERNEL_BUILDER(Name("GetSession").Device(DEVICE_CPU), GetSession); ...@@ -141,7 +237,6 @@ REGISTER_KERNEL_BUILDER(Name("GetSession").Device(DEVICE_CPU), GetSession);
class ReleaseSession : public OpKernel { class ReleaseSession : public OpKernel {
public: public:
explicit ReleaseSession(OpKernelConstruction *context) : OpKernel(context) { explicit ReleaseSession(OpKernelConstruction *context) : OpKernel(context) {
string master_spec_str;
OP_REQUIRES_OK(context, context->MatchSignature({DT_STRING}, {})); OP_REQUIRES_OK(context, context->MatchSignature({DT_STRING}, {}));
} }
...@@ -188,6 +283,53 @@ class ReleaseSession : public OpKernel { ...@@ -188,6 +283,53 @@ class ReleaseSession : public OpKernel {
REGISTER_KERNEL_BUILDER(Name("ReleaseSession").Device(DEVICE_CPU), REGISTER_KERNEL_BUILDER(Name("ReleaseSession").Device(DEVICE_CPU),
ReleaseSession); ReleaseSession);
// Returns statistics about session loads to the graph. This op returns the
// total number of created Session objects and the number of those objects
// that are currently being used in the ComputeSessionPool.
class GetSessionCounts : public OpKernel {
public:
explicit GetSessionCounts(OpKernelConstruction *context) : OpKernel(context) {
OP_REQUIRES_OK(context, context->MatchSignature({DT_STRING}, {DT_INT64}));
}
void Compute(OpKernelContext *context) override {
const string container = context->input(0).scalar<string>()();
VLOG(1) << "Getting stats for container: " << container;
ResourceMgr *rmgr = context->resource_manager();
// Allocate the output tensors.
Tensor *output;
OP_REQUIRES_OK(context,
context->allocate_output(0, TensorShape({2}), &output));
// Get the pool for this container.
ComputeSessionPoolResource *pool_resource;
auto result = rmgr->Lookup<ComputeSessionPoolResource>(container, "pool",
&pool_resource);
if (!result.ok()) {
// If there's no ComputeSessionPoolResource, report 0 sessions created
// and 0 available.
output->vec<int64>()(0) = 0;
output->vec<int64>()(1) = 0;
return;
}
auto *pool = pool_resource->get();
CHECK(pool != nullptr);
output->vec<int64>()(0) = pool->num_unique_sessions();
output->vec<int64>()(1) = pool->num_outstanding_sessions();
pool_resource->Unref();
}
private:
TF_DISALLOW_COPY_AND_ASSIGN(GetSessionCounts);
};
REGISTER_KERNEL_BUILDER(Name("GetSessionCounts").Device(DEVICE_CPU),
GetSessionCounts);
/******************************************************************************* /*******************************************************************************
* ComputeSessionOps below here. * ComputeSessionOps below here.
******************************************************************************/ ******************************************************************************/
...@@ -233,9 +375,17 @@ class AdvanceFromPrediction : public ComputeSessionOp { ...@@ -233,9 +375,17 @@ class AdvanceFromPrediction : public ComputeSessionOp {
void ComputeWithState(OpKernelContext *context, void ComputeWithState(OpKernelContext *context,
ComputeSession *session) override { ComputeSession *session) override {
const Tensor &scores = context->input(1); const Tensor &scores = context->input(1);
session->AdvanceFromPrediction(component_name(), const int num_items = scores.shape().dim_size(0);
scores.tensor<float, 2>().data(), const int num_actions = scores.shape().dim_size(1);
scores.NumElements()); bool success = session->AdvanceFromPrediction(
component_name(), scores.tensor<float, 2>().data(), num_items,
num_actions);
if (success) {
VLOG(2) << "Score: " << scores.tensor<float, 2>();
}
OP_REQUIRES(
context, success,
tensorflow::errors::Internal("Unable to advance from prediction."));
} }
private: private:
...@@ -247,13 +397,12 @@ REGISTER_KERNEL_BUILDER(Name("AdvanceFromPrediction").Device(DEVICE_CPU), ...@@ -247,13 +397,12 @@ REGISTER_KERNEL_BUILDER(Name("AdvanceFromPrediction").Device(DEVICE_CPU),
// Given a handle to a ComputeSession and a channel index, outputs fixed // Given a handle to a ComputeSession and a channel index, outputs fixed
// features. // features.
// Fixed features are returned as 3 vectors or equal length: // Fixed features are returned as 3 vectors of equal length:
// - ids: specifies which rows should be looked up in the embedding // - ids: specifies which rows should be looked up in the embedding
// matrix, // matrix,
// - weights: specifies a scale for each embedding vector, // - weights: specifies a scale for each embedding vector,
// - indices: sorted vector that assigns the same index to embedding // - indices: sorted vector that assigns the same index to embedding
// vectors // vectors that should be summed together.
// that should be summed together.
// //
// For example if we have 3 features, for a given channel, we might have: // For example if we have 3 features, for a given channel, we might have:
// feature a: (5, 1) // feature a: (5, 1)
...@@ -300,7 +449,10 @@ class ExtractFixedFeatures : public ComputeSessionOp { ...@@ -300,7 +449,10 @@ class ExtractFixedFeatures : public ComputeSessionOp {
int num_features = session->GetInputFeatures( int num_features = session->GetInputFeatures(
component_name(), indices_allocator, ids_allocator, weights_allocator, component_name(), indices_allocator, ids_allocator, weights_allocator,
channel_id_); channel_id_);
VLOG(2) << "Extracted " << num_features; VLOG(2) << "Extracted features (" << num_features << "): "
<< " ids=" << context->mutable_output(1)->vec<int64>()
<< " weights=" << context->mutable_output(2)->vec<float>()
<< " indices=" << context->mutable_output(0)->vec<int32>();
} }
private: private:
...@@ -524,6 +676,7 @@ class AttachDataReader : public ComputeSessionOp { ...@@ -524,6 +676,7 @@ class AttachDataReader : public ComputeSessionOp {
auto input_data(context->input(1).vec<string>()); auto input_data(context->input(1).vec<string>());
std::vector<string> data; std::vector<string> data;
data.reserve(input_data.size());
for (int i = 0; i < input_data.size(); ++i) { for (int i = 0; i < input_data.size(); ++i) {
data.push_back(input_data(i)); data.push_back(input_data(i));
} }
...@@ -642,5 +795,6 @@ class GetComponentTrace : public ComputeSessionOp { ...@@ -642,5 +795,6 @@ class GetComponentTrace : public ComputeSessionOp {
REGISTER_KERNEL_BUILDER(Name("GetComponentTrace").Device(DEVICE_CPU), REGISTER_KERNEL_BUILDER(Name("GetComponentTrace").Device(DEVICE_CPU),
GetComponentTrace); GetComponentTrace);
} // namespace
} // namespace dragnn } // namespace dragnn
} // namespace syntaxnet } // namespace syntaxnet
...@@ -17,6 +17,7 @@ ...@@ -17,6 +17,7 @@
#include <memory> #include <memory>
#include <vector> #include <vector>
#include "dragnn/core/component_registry.h"
#include "dragnn/core/compute_session.h" #include "dragnn/core/compute_session.h"
#include "dragnn/core/compute_session_pool.h" #include "dragnn/core/compute_session_pool.h"
#include "dragnn/core/resource_container.h" #include "dragnn/core/resource_container.h"
...@@ -66,6 +67,87 @@ using testing::Return; ...@@ -66,6 +67,87 @@ using testing::Return;
typedef ResourceContainer<ComputeSession> ComputeSessionResource; typedef ResourceContainer<ComputeSession> ComputeSessionResource;
typedef ResourceContainer<ComputeSessionPool> ComputeSessionPoolResource; typedef ResourceContainer<ComputeSessionPool> ComputeSessionPoolResource;
typedef ResourceContainer<string> StringResource;
namespace {
const char kGlobalContainer[] = "__reserved_global_container";
const char kBasePathTag[] = "__reserved_asset_base_path";
const char kUnmanagedAssetDirectory[] = "assets.extra";
} // namespace
// Define a test component to validate registered construction.
// Every virtual is stubbed out with a fixed, trivial answer; only the
// component name from InitializeComponent is retained.
class TestComponent : public Component {
 public:
  TestComponent() {}

  // Records the component name from the spec.
  void InitializeComponent(const ComponentSpec &spec) override {
    name_ = spec.name();
  }
  void InitializeData(
      const std::vector<std::vector<const TransitionState *>> &states,
      int max_beam_size, InputBatchCache *input_data) override {}
  void InitializeTracing() override {}
  void DisableTracing() override {}

  // Trivial fixed answers for state queries.
  bool IsReady() const override { return true; }
  string Name() const override { return name_; }
  int BeamSize() const override { return 3; }
  int BatchSize() const override { return 1; }
  int StepsTaken(int batch_index) const override { return 0; }
  int GetBeamIndexAtStep(int step, int current_index,
                         int batch) const override {
    return 0;
  }
  int GetSourceBeamIndex(int current_index, int batch) const override {
    return 0;
  }

  // Transitions always "succeed" and the component is always terminal.
  bool AdvanceFromPrediction(const float *score_matrix, int num_items,
                             int num_actions) override {
    return true;
  }
  void AdvanceFromOracle() override {}
  bool IsTerminal() const override { return true; }
  std::function<int(int, int, int)> GetStepLookupFunction(
      const string &method) override {
    return nullptr;
  }
  std::vector<std::vector<const TransitionState *>> GetBeam() override {
    return {};
  }

  // Feature extraction stubs: no features, empty results.
  int GetFixedFeatures(std::function<int32 *(int)> allocate_indices,
                       std::function<int64 *(int)> allocate_ids,
                       std::function<float *(int)> allocate_weights,
                       int channel_id) const override {
    return 0;
  }
  int BulkGetFixedFeatures(const BulkFeatureExtractor &extractor) override {
    return 0;
  }
  void BulkEmbedFixedFeatures(
      int batch_size_padding, int num_steps_padding, int output_array_size,
      const vector<const float *> &per_channel_embeddings,
      float *embedding_matrix) override {}
  std::vector<LinkFeatures> GetRawLinkFeatures(int channel_id) const override {
    return {};
  }
  std::vector<std::vector<int>> GetOracleLabels() const override {
    return {};
  }
  void FinalizeData() override {}
  void ResetComponent() override {}
  std::vector<std::vector<ComponentTrace>> GetTraceProtos() const override {
    return {};
  }
  void AddTranslatedLinkFeaturesToTrace(
      const std::vector<LinkFeatures> &features, int channel_id) override {}

  string name_;
};

REGISTER_DRAGNN_COMPONENT(TestComponent);
class DragnnOpKernelsTest : public tensorflow::OpsTestBase { class DragnnOpKernelsTest : public tensorflow::OpsTestBase {
public: public:
...@@ -106,6 +188,42 @@ LinkFeatures MakeFeatures(int batch_index, int beam_index, int step) { ...@@ -106,6 +188,42 @@ LinkFeatures MakeFeatures(int batch_index, int beam_index, int step) {
return features; return features;
} }
// The SetAssetDirectory op should
// 1. When given an asset path (foo/bar/baz/asset/thing), strip the path to
//    foo/bar/baz and add 'assets.extra' to it.
// 2. Store that path in the resource manager.
TEST_F(DragnnOpKernelsTest, SetAssetDirectoryTest) {
  // The input asset path, and the rewritten path we expect the op to store:
  // two trailing components stripped, then kUnmanagedAssetDirectory appended.
  const string new_asset_path = "new/directory/path/asset/master_spec";
  const string expected_asset_path =
      StrCat("new/directory/path/", kUnmanagedAssetDirectory);

  // Create and initialize the kernel under test.
  TF_ASSERT_OK(NodeDefBuilder("set_asset_directory", "SetAssetDirectory")
                   .Input(FakeInput(DT_STRING))  // The new asset path.
                   .Finalize(node_def()));
  TF_ASSERT_OK(InitOp());

  // Set the input data.
  AddInputFromList<string>(TensorShape({1}), {new_asset_path});

  // Reset the test context to ensure it's clean.
  ResetOpKernelContext();

  // Run the kernel.
  TF_EXPECT_OK(RunOpKernelWithContext());

  // Expect that the ResourceMgr contains the correct string.
  StringResource *resource;
  TF_EXPECT_OK(resource_mgr()->Lookup<StringResource>(kGlobalContainer,
                                                      kBasePathTag, &resource));
  EXPECT_EQ(*resource->get(), expected_asset_path);
  resource->Unref();
}
// The GetSessionOp should // The GetSessionOp should
// 1. create a ComputeSessionPool resource and store it in the ResourceMgr, // 1. create a ComputeSessionPool resource and store it in the ResourceMgr,
// 2. create a ComputeSession resource and store it in the ResourceMgr, // 2. create a ComputeSession resource and store it in the ResourceMgr,
...@@ -164,6 +282,103 @@ TEST_F(DragnnOpKernelsTest, GetSessionOpTest) { ...@@ -164,6 +282,103 @@ TEST_F(DragnnOpKernelsTest, GetSessionOpTest) {
pool_resource->Unref(); pool_resource->Unref();
} }
// If an asset_base_path resource exists, the GetSession op should prepend
// that path to all paths in the MasterSpec before creating a session.
TEST_F(DragnnOpKernelsTest, GetSessionWithAssetBasePathTest) {
  // Create a MasterSpec and GridPoint string to pass into the attrs for this
  // op.
  const string new_asset_path = "new/base";
  MasterSpec spec;

  // The first component in the MasterSpec has one resource with one part.
  auto component_one = spec.add_component();
  auto backend_one = component_one->mutable_backend();
  backend_one->set_registered_name("TestComponent");
  component_one->add_resource()->add_part()->set_file_pattern(
      "path/to/an/asset.txt");
  const string expected_component_one_asset = "new/base/path/to/an/asset.txt";

  auto component_two = spec.add_component();
  auto backend_two = component_two->mutable_backend();
  backend_two->set_registered_name("TestComponent");

  // The second component's first resource has no assets.
  component_two->add_resource();

  // The second component's second resource has one part.
  vector<string> expected_component_two_assets;
  component_two->add_resource()->add_part()->set_file_pattern(
      "another/dir/with/an/asset.txt");
  expected_component_two_assets.push_back(
      "new/base/another/dir/with/an/asset.txt");

  // The second component's third resource has two parts.
  auto third_resource = component_two->add_resource();
  third_resource->add_part()->set_file_pattern(
      "another/dir/with/an/asset3.jif");
  expected_component_two_assets.push_back(
      "new/base/another/dir/with/an/asset3.jif");
  third_resource->add_part()->set_file_pattern(
      "another/dir/with/an/asset4.jif");
  expected_component_two_assets.push_back(
      "new/base/another/dir/with/an/asset4.jif");

  LOG(INFO) << spec.DebugString();
  string master_spec_str;
  spec.SerializeToString(&master_spec_str);

  // An empty GridPoint is sufficient; only the spec rewriting is under test.
  GridPoint hyperparams;
  string hyperparams_str;
  hyperparams.SerializeToString(&hyperparams_str);

  // Create and initialize the kernel under test.
  TF_ASSERT_OK(
      NodeDefBuilder("get_session", "GetSession")
          .Attr("master_spec", master_spec_str)
          .Attr("grid_point", hyperparams_str)
          .Input(FakeInput(DT_STRING))  // The handle for the ComputeSession.
          .Finalize(node_def()));
  TF_ASSERT_OK(InitOp());

  // Set the input data.
  const string container_string = "container_str";
  AddInputFromList<string>(TensorShape({1}), {container_string});

  // Reset the test context to ensure it's clean.
  ResetOpKernelContext();

  // Create the base-path string in the resource manager, as SetAssetDirectory
  // would have done; GetSession should find it and rewrite the spec.
  std::unique_ptr<string> asset_path_ptr(new string(new_asset_path));
  TF_EXPECT_OK(resource_mgr()->Create<StringResource>(
      kGlobalContainer, kBasePathTag,
      new StringResource(std::move(asset_path_ptr))));

  // Run the kernel.
  TF_EXPECT_OK(RunOpKernelWithContext());

  // Expect that the ResourceMgr contains a ComputeSessionPoolResource.
  const string pool_id_str = "pool";
  ComputeSessionPoolResource *pool_resource;
  TF_EXPECT_OK(resource_mgr()->Lookup<ComputeSessionPoolResource>(
      container_string, pool_id_str, &pool_resource));

  // Validate that the master spec held by the pool has the new directory names.
  auto rewritten_spec = pool_resource->get()->GetSpec();
  EXPECT_EQ(rewritten_spec.component(0).resource(0).part(0).file_pattern(),
            expected_component_one_asset);
  EXPECT_EQ(rewritten_spec.component(1).resource(1).part(0).file_pattern(),
            expected_component_two_assets.at(0));
  EXPECT_EQ(rewritten_spec.component(1).resource(2).part(0).file_pattern(),
            expected_component_two_assets.at(1));
  EXPECT_EQ(rewritten_spec.component(1).resource(2).part(1).file_pattern(),
            expected_component_two_assets.at(2));

  // Unref the managed resources so they get destroyed properly.
  pool_resource->Unref();
}
// The GetSessionOp should take a session stored in the resource manager // The GetSessionOp should take a session stored in the resource manager
// and return it to the ComputeSessionPool. // and return it to the ComputeSessionPool.
TEST_F(DragnnOpKernelsTest, ReleaseSessionOpTest) { TEST_F(DragnnOpKernelsTest, ReleaseSessionOpTest) {
...@@ -217,6 +432,56 @@ TEST_F(DragnnOpKernelsTest, ReleaseSessionOpTest) { ...@@ -217,6 +432,56 @@ TEST_F(DragnnOpKernelsTest, ReleaseSessionOpTest) {
EXPECT_EQ(null_resource, nullptr); EXPECT_EQ(null_resource, nullptr);
} }
// The GetSessionCounts op should report the number of sessions created and
// free.
TEST_F(DragnnOpKernelsTest, GetSessionCountsOpTest) {
  // Create and initialize the kernel under test.
  TF_ASSERT_OK(
      NodeDefBuilder("get_session_counts", "GetSessionCounts")
          .Input(FakeInput(DT_STRING))  // The handle for the ComputeSession.
          .Finalize(node_def()));
  TF_ASSERT_OK(InitOp());

  // Set the input data.
  const string container_string = "container_str";
  AddInputFromList<string>(TensorShape({1}), {container_string});

  // Reset the test context to ensure it's clean.
  ResetOpKernelContext();

  // Create a ComputeSessionPool.
  MasterSpec spec;
  GridPoint hyperparams;
  std::unique_ptr<ComputeSessionPool> pool(
      new ComputeSessionPool(spec, hyperparams));

  // Get an unowned pointer to the ComputeSessionPool before moving
  // the pool to the resource manager.
  ComputeSessionPool *pool_ptr = pool.get();
  TF_ASSERT_OK(resource_mgr()->Create<ComputeSessionPoolResource>(
      container_string, "pool",
      new ComputeSessionPoolResource(std::move(pool))));

  // Create two ComputeSessions.
  auto session_one = pool_ptr->GetSession();
  auto session_two = pool_ptr->GetSession();

  // Return one of them.
  pool_ptr->ReturnSession(std::move(session_two));

  // At this point, the pool should report that it has one outstanding session
  // and two sessions total.
  EXPECT_EQ(1, pool_ptr->num_outstanding_sessions());
  EXPECT_EQ(2, pool_ptr->num_unique_sessions());

  // Run the kernel.
  TF_EXPECT_OK(RunOpKernelWithContext());

  // The op's output vector should match the pool's own counts:
  // [0] = total created, [1] = currently outstanding.
  EXPECT_EQ(pool_ptr->num_unique_sessions(), GetOutput(0)->vec<int64>()(0));
  EXPECT_EQ(pool_ptr->num_outstanding_sessions(),
            GetOutput(0)->vec<int64>()(1));
}
// The AdvanceFromOracle op should call AdvanceFromOracle on the specified // The AdvanceFromOracle op should call AdvanceFromOracle on the specified
// component name. // component name.
TEST_F(DragnnOpKernelsTest, AdvanceFromOracleOpTest) { TEST_F(DragnnOpKernelsTest, AdvanceFromOracleOpTest) {
...@@ -287,14 +552,65 @@ TEST_F(DragnnOpKernelsTest, AdvanceFromPredictionOpTest) { ...@@ -287,14 +552,65 @@ TEST_F(DragnnOpKernelsTest, AdvanceFromPredictionOpTest) {
// Set expectations on the mock session. // Set expectations on the mock session.
auto validator_function = [weights](const string &component_name, auto validator_function = [weights](const string &component_name,
const float score_matrix[], const float *score_matrix, int num_items,
int score_matrix_length) { int num_actions) {
EXPECT_EQ(weights.size(), score_matrix_length); EXPECT_EQ(weights.size(), num_items * num_actions);
for (int i = 0; i < weights.size(); ++i) {
EXPECT_EQ(weights[i], score_matrix[i]);
}
return true;
};
EXPECT_CALL(*mock_session_ptr, AdvanceFromPrediction(component_name, _, _, _))
.WillOnce(Invoke(validator_function));
// Run the kernel.
TF_EXPECT_OK(RunOpKernelWithContext());
}
// The AdvanceFromPredicton op should call AdvanceFromPrediction on the
// specified component with the passed scores. If it returns false, the op
// should not return OK.
TEST_F(DragnnOpKernelsTest, AdvanceFromPredictionFailureTest) {
// Create and initialize the kernel under test.
const string component_name = "TESTING_COMPONENT_NAME";
TF_ASSERT_OK(
NodeDefBuilder("advance_from_prediction", "AdvanceFromPrediction")
.Attr("component", component_name)
.Input(FakeInput(DT_STRING)) // The handle for the ComputeSession.
.Input(FakeInput(DT_FLOAT)) // The prediction tensor.
.Finalize(node_def()));
TF_ASSERT_OK(InitOp());
// Set the input data.
const string container_string = "container_str";
const string id_string = "id_str";
AddInputFromList<string>(TensorShape({2}), {container_string, id_string});
const std::vector<float> weights = {1.1, 2.2, 3.3, 4.4};
AddInputFromArray<float>(TensorShape({2, 2}), weights);
// Reset the test context to ensure it's clean.
ResetOpKernelContext();
// Create a MockComputeSession and set expectations.
std::unique_ptr<MockComputeSession> mock_session(new MockComputeSession());
MockComputeSession *mock_session_ptr = mock_session.get();
// Wrap the ComputeSessionResource and put it into the resource manager.
TF_ASSERT_OK(resource_mgr()->Create<ComputeSessionResource>(
container_string, id_string,
new ComputeSessionResource(std::move(mock_session))));
// Set expectations on the mock session.
auto validator_function = [weights](const string &component_name,
const float *score_matrix, int num_items,
int num_actions) {
EXPECT_EQ(weights.size(), num_items * num_actions);
for (int i = 0; i < weights.size(); ++i) { for (int i = 0; i < weights.size(); ++i) {
EXPECT_EQ(weights[i], score_matrix[i]); EXPECT_EQ(weights[i], score_matrix[i]);
} }
return true;
}; };
EXPECT_CALL(*mock_session_ptr, AdvanceFromPrediction(component_name, _, _)) EXPECT_CALL(*mock_session_ptr, AdvanceFromPrediction(component_name, _, _, _))
.WillOnce(Invoke(validator_function)); .WillOnce(Invoke(validator_function));
// Run the kernel. // Run the kernel.
......
...@@ -18,6 +18,20 @@ ...@@ -18,6 +18,20 @@
namespace syntaxnet { namespace syntaxnet {
namespace dragnn { namespace dragnn {
// Registers the SetAssetDirectory op. Stateful: it writes a StringResource
// into the resource manager that later GetSession calls consume.
REGISTER_OP("SetAssetDirectory")
    .Input("asset_directory: string")
    .Output("asset_directory_out: string")
    .SetIsStateful()
    .Doc(R"doc(
Override the paths to assets specified in the MasterSpec with the given
asset_directory. This op must be called before any calls to GetSession, as it
will create a new session pool with the overridden master spec.

asset_directory: The directory containing all the assets. Note that all assets
  must be in a single flat directory.
asset_directory_out: The input, just as an output.
)doc");
REGISTER_OP("GetSession") REGISTER_OP("GetSession")
.Input("container: string") .Input("container: string")
.Attr("master_spec: string") .Attr("master_spec: string")
...@@ -42,6 +56,18 @@ This ComputeSession will no longer be available after this op returns. ...@@ -42,6 +56,18 @@ This ComputeSession will no longer be available after this op returns.
handle: A handle to a ComputeSession that will be returned to the backing pool. handle: A handle to a ComputeSession that will be returned to the backing pool.
)doc"); )doc");
// Registers the GetSessionCounts op. Stateful: it reads pool state from the
// resource manager, which can change between invocations.
REGISTER_OP("GetSessionCounts")
    .Input("container: string")
    .Output("stats: int64")
    .SetIsStateful()
    .Doc(R"doc(
Given a container string, output session counts for that ComputeSessionPool.

container: A unique identifier for the ComputeSessionPool to analyze.
stats: A vector of stats. [0] is the total number of created sessions. [1] is
  the number of sessions that are currently not in the pool.
)doc");
REGISTER_OP("InitComponentData") REGISTER_OP("InitComponentData")
.Input("handle: string") .Input("handle: string")
.Input("beam_size: int32") .Input("beam_size: int32")
...@@ -123,28 +149,6 @@ component: The name of a Component instance, matching the ComponentSpec.name. ...@@ -123,28 +149,6 @@ component: The name of a Component instance, matching the ComponentSpec.name.
output_handle: A handle to the same ComputeSession after advancement. output_handle: A handle to the same ComputeSession after advancement.
)doc"); )doc");
REGISTER_OP("DragnnEmbeddingInitializer")
.Output("embeddings: float")
.Attr("embedding_input: string")
.Attr("vocab: string")
.Attr("scaling_coefficient: float = 1.0")
.Attr("seed: int = 0")
.Attr("seed2: int = 0")
.Doc(R"doc(
*** PLACEHOLDER OP - FUNCTIONALITY NOT YET IMPLEMENTED ***
Read embeddings from an an input for every key specified in a text vocab file.
embeddings: A tensor containing embeddings from the specified sstable.
embedding_input: Path to location with embedding vectors.
vocab: Path to list of keys corresponding to the input.
scaling_coefficient: A scaling coefficient for the embedding matrix.
seed: If either `seed` or `seed2` are set to be non-zero, the random number
generator is seeded by the given seed. Otherwise, it is seeded by a
random seed.
seed2: A second seed to avoid seed collision.
)doc");
REGISTER_OP("ExtractFixedFeatures") REGISTER_OP("ExtractFixedFeatures")
.Input("handle: string") .Input("handle: string")
.Output("indices: int32") .Output("indices: int32")
......
...@@ -13,8 +13,8 @@ ...@@ -13,8 +13,8 @@
// limitations under the License. // limitations under the License.
// ============================================================================= // =============================================================================
#ifndef NLP_SAFT_OPENSOURCE_DRAGNN_CORE_RESOURCE_CONTAINER_H_ #ifndef DRAGNN_CORE_RESOURCE_CONTAINER_H_
#define NLP_SAFT_OPENSOURCE_DRAGNN_CORE_RESOURCE_CONTAINER_H_ #define DRAGNN_CORE_RESOURCE_CONTAINER_H_
#include <memory> #include <memory>
...@@ -48,4 +48,4 @@ class ResourceContainer : public tensorflow::ResourceBase { ...@@ -48,4 +48,4 @@ class ResourceContainer : public tensorflow::ResourceBase {
} // namespace dragnn } // namespace dragnn
} // namespace syntaxnet } // namespace syntaxnet
#endif // NLP_SAFT_OPENSOURCE_DRAGNN_CORE_RESOURCE_CONTAINER_H_ #endif // DRAGNN_CORE_RESOURCE_CONTAINER_H_
...@@ -26,6 +26,7 @@ cc_library( ...@@ -26,6 +26,7 @@ cc_library(
deps = [ deps = [
"//dragnn/components/util:bulk_feature_extractor", "//dragnn/components/util:bulk_feature_extractor",
"//dragnn/core:compute_session", "//dragnn/core:compute_session",
"//dragnn/core:input_batch_cache",
"//dragnn/protos:data_proto", "//dragnn/protos:data_proto",
"//dragnn/protos:spec_proto", "//dragnn/protos:spec_proto",
"//syntaxnet:base", "//syntaxnet:base",
......
...@@ -13,8 +13,8 @@ ...@@ -13,8 +13,8 @@
// limitations under the License. // limitations under the License.
// ============================================================================= // =============================================================================
#ifndef NLP_SAFT_OPENSOURCE_DRAGNN_CORE_TEST_GENERIC_H_ #ifndef DRAGNN_CORE_TEST_GENERIC_H_
#define NLP_SAFT_OPENSOURCE_DRAGNN_CORE_TEST_GENERIC_H_ #define DRAGNN_CORE_TEST_GENERIC_H_
#include <utility> #include <utility>
...@@ -31,10 +31,18 @@ MATCHER_P(EqualsProto, a, "Protos are not equivalent:") { ...@@ -31,10 +31,18 @@ MATCHER_P(EqualsProto, a, "Protos are not equivalent:") {
return a.DebugString() == arg.DebugString(); return a.DebugString() == arg.DebugString();
} }
// Matches an error status whose message matches |substr|.
//
// |arg| must provide ok() and error_message() (e.g. tensorflow::Status). The
// matcher succeeds only for non-OK statuses whose message contains |substr|
// as a substring; |negation| is supplied by the MATCHER_P macro and selects
// the wording of the failure description.
MATCHER_P(IsErrorWithSubstr, substr,
          string(negation ? "isn't" : "is") +
              " an error Status whose message matches the substring '" +
              ::testing::PrintToString(substr) + "'") {
  return !arg.ok() && arg.error_message().find(substr) != string::npos;
}
// Returns the prefix for where the test data is stored. // Returns the prefix for where the test data is stored.
string GetTestDataPrefix(); string GetTestDataPrefix();
} // namespace test } // namespace test
} // namespace syntaxnet } // namespace syntaxnet
#endif // NLP_SAFT_OPENSOURCE_DRAGNN_CORE_TEST_GENERIC_H_ #endif // DRAGNN_CORE_TEST_GENERIC_H_
...@@ -13,8 +13,8 @@ ...@@ -13,8 +13,8 @@
// limitations under the License. // limitations under the License.
// ============================================================================= // =============================================================================
#ifndef NLP_SAFT_OPENSOURCE_DRAGNN_CORE_TEST_MOCK_COMPONENT_H_ #ifndef DRAGNN_CORE_TEST_MOCK_COMPONENT_H_
#define NLP_SAFT_OPENSOURCE_DRAGNN_CORE_TEST_MOCK_COMPONENT_H_ #define DRAGNN_CORE_TEST_MOCK_COMPONENT_H_
#include <gmock/gmock.h> #include <gmock/gmock.h>
...@@ -47,8 +47,8 @@ class MockComponent : public Component { ...@@ -47,8 +47,8 @@ class MockComponent : public Component {
MOCK_CONST_METHOD3(GetBeamIndexAtStep, MOCK_CONST_METHOD3(GetBeamIndexAtStep,
int(int step, int current_index, int batch)); int(int step, int current_index, int batch));
MOCK_CONST_METHOD2(GetSourceBeamIndex, int(int current_index, int batch)); MOCK_CONST_METHOD2(GetSourceBeamIndex, int(int current_index, int batch));
MOCK_METHOD2(AdvanceFromPrediction, MOCK_METHOD3(AdvanceFromPrediction, bool(const float *transition_matrix,
void(const float transition_matrix[], int matrix_length)); int num_items, int num_actions));
MOCK_METHOD0(AdvanceFromOracle, void()); MOCK_METHOD0(AdvanceFromOracle, void());
MOCK_CONST_METHOD0(IsTerminal, bool()); MOCK_CONST_METHOD0(IsTerminal, bool());
MOCK_METHOD0(GetBeam, std::vector<std::vector<const TransitionState *>>()); MOCK_METHOD0(GetBeam, std::vector<std::vector<const TransitionState *>>());
...@@ -59,6 +59,11 @@ class MockComponent : public Component { ...@@ -59,6 +59,11 @@ class MockComponent : public Component {
int channel_id)); int channel_id));
MOCK_METHOD1(BulkGetFixedFeatures, MOCK_METHOD1(BulkGetFixedFeatures,
int(const BulkFeatureExtractor &extractor)); int(const BulkFeatureExtractor &extractor));
// Mocks Component::BulkEmbedFixedFeatures, which embeds all fixed features
// directly into |embedding_output| using the per-channel embedding matrices,
// padding to the given batch and step sizes (see the BulkEmbedFixedFeatures
// op registration).
MOCK_METHOD5(BulkEmbedFixedFeatures,
             void(int batch_size_padding, int num_steps_padding,
                  int output_array_size,
                  const vector<const float *> &per_channel_embeddings,
                  float *embedding_output));
MOCK_CONST_METHOD1(GetRawLinkFeatures, MOCK_CONST_METHOD1(GetRawLinkFeatures,
std::vector<LinkFeatures>(int channel_id)); std::vector<LinkFeatures>(int channel_id));
MOCK_CONST_METHOD0(GetOracleLabels, std::vector<std::vector<int>>()); MOCK_CONST_METHOD0(GetOracleLabels, std::vector<std::vector<int>>());
...@@ -75,4 +80,4 @@ class MockComponent : public Component { ...@@ -75,4 +80,4 @@ class MockComponent : public Component {
} // namespace dragnn } // namespace dragnn
} // namespace syntaxnet } // namespace syntaxnet
#endif // NLP_SAFT_OPENSOURCE_DRAGNN_CORE_TEST_MOCK_COMPONENT_H_ #endif // DRAGNN_CORE_TEST_MOCK_COMPONENT_H_
...@@ -13,16 +13,18 @@ ...@@ -13,16 +13,18 @@
// limitations under the License. // limitations under the License.
// ============================================================================= // =============================================================================
#ifndef NLP_SAFT_OPENSOURCE_DRAGNN_CORE_TEST_MOCK_COMPUTE_SESSION_H_ #ifndef DRAGNN_CORE_TEST_MOCK_COMPUTE_SESSION_H_
#define NLP_SAFT_OPENSOURCE_DRAGNN_CORE_TEST_MOCK_COMPUTE_SESSION_H_ #define DRAGNN_CORE_TEST_MOCK_COMPUTE_SESSION_H_
#include <gmock/gmock.h> #include <memory>
#include "dragnn/components/util/bulk_feature_extractor.h" #include "dragnn/components/util/bulk_feature_extractor.h"
#include "dragnn/core/compute_session.h" #include "dragnn/core/compute_session.h"
#include "dragnn/core/input_batch_cache.h"
#include "dragnn/protos/data.pb.h" #include "dragnn/protos/data.pb.h"
#include "dragnn/protos/spec.pb.h" #include "dragnn/protos/spec.pb.h"
#include "syntaxnet/base.h" #include "syntaxnet/base.h"
#include <gmock/gmock.h>
#include "tensorflow/core/platform/test.h" #include "tensorflow/core/platform/test.h"
namespace syntaxnet { namespace syntaxnet {
...@@ -40,9 +42,9 @@ class MockComputeSession : public ComputeSession { ...@@ -40,9 +42,9 @@ class MockComputeSession : public ComputeSession {
MOCK_METHOD2(SourceComponentBeamSize, MOCK_METHOD2(SourceComponentBeamSize,
int(const string &component_name, int channel_id)); int(const string &component_name, int channel_id));
MOCK_METHOD1(AdvanceFromOracle, void(const string &component_name)); MOCK_METHOD1(AdvanceFromOracle, void(const string &component_name));
MOCK_METHOD3(AdvanceFromPrediction, MOCK_METHOD4(AdvanceFromPrediction,
void(const string &component_name, const float score_matrix[], bool(const string &component_name, const float *score_matrix,
int score_matrix_length)); int num_items, int num_actions));
MOCK_CONST_METHOD5(GetInputFeatures, MOCK_CONST_METHOD5(GetInputFeatures,
int(const string &component_name, int(const string &component_name,
std::function<int32 *(int)> allocate_indices, std::function<int32 *(int)> allocate_indices,
...@@ -52,6 +54,11 @@ class MockComputeSession : public ComputeSession { ...@@ -52,6 +54,11 @@ class MockComputeSession : public ComputeSession {
MOCK_METHOD2(BulkGetInputFeatures, MOCK_METHOD2(BulkGetInputFeatures,
int(const string &component_name, int(const string &component_name,
const BulkFeatureExtractor &extractor)); const BulkFeatureExtractor &extractor));
// Mocks ComputeSession::BulkEmbedFixedFeatures, which forwards the bulk
// embedding request to the component named |component_name|.
MOCK_METHOD6(BulkEmbedFixedFeatures,
             void(const string &component_name, int batch_size_padding,
                  int num_steps_padding, int output_array_size,
                  const vector<const float *> &per_channel_embedding,
                  float *embedding_output));
MOCK_METHOD2(GetTranslatedLinkFeatures, MOCK_METHOD2(GetTranslatedLinkFeatures,
std::vector<LinkFeatures>(const string &component_name, std::vector<LinkFeatures>(const string &component_name,
int channel_id)); int channel_id));
...@@ -68,9 +75,17 @@ class MockComputeSession : public ComputeSession { ...@@ -68,9 +75,17 @@ class MockComputeSession : public ComputeSession {
MOCK_CONST_METHOD1(GetDescription, string(const string &component_name)); MOCK_CONST_METHOD1(GetDescription, string(const string &component_name));
MOCK_CONST_METHOD1(Translators, const std::vector<const IndexTranslator *>( MOCK_CONST_METHOD1(Translators, const std::vector<const IndexTranslator *>(
const string &component_name)); const string &component_name));
MOCK_CONST_METHOD1(GetReadiedComponent, Component *(const string &name));

// TODO(googleuser): Upgrade gMock to a version that supports mocking methods
// with move-only types, then remove this workaround.
MOCK_METHOD1(DoSetInputBatchCache, void(InputBatchCache *batch));

// Forwards to DoSetInputBatchCache() because gMock cannot mock a method that
// takes a move-only std::unique_ptr parameter. Note: the unique_ptr (and the
// InputBatchCache it owns) is destroyed when this call returns, so the raw
// pointer handed to the mock must not be retained by expectations.
void SetInputBatchCache(std::unique_ptr<InputBatchCache> batch) override {
  DoSetInputBatchCache(batch.get());
}
}; };
} // namespace dragnn } // namespace dragnn
} // namespace syntaxnet } // namespace syntaxnet
#endif // NLP_SAFT_OPENSOURCE_DRAGNN_CORE_TEST_MOCK_COMPUTE_SESSION_H_ #endif // DRAGNN_CORE_TEST_MOCK_COMPUTE_SESSION_H_
...@@ -13,8 +13,8 @@ ...@@ -13,8 +13,8 @@
// limitations under the License. // limitations under the License.
// ============================================================================= // =============================================================================
#ifndef NLP_SAFT_OPENSOURCE_DRAGNN_CORE_TEST_MOCK_TRANSITION_STATE_H_ #ifndef DRAGNN_CORE_TEST_MOCK_TRANSITION_STATE_H_
#define NLP_SAFT_OPENSOURCE_DRAGNN_CORE_TEST_MOCK_TRANSITION_STATE_H_ #define DRAGNN_CORE_TEST_MOCK_TRANSITION_STATE_H_
#include <memory> #include <memory>
...@@ -31,15 +31,17 @@ class MockTransitionState : public TransitionState { ...@@ -31,15 +31,17 @@ class MockTransitionState : public TransitionState {
public: public:
MOCK_METHOD1(Init, void(const TransitionState &parent)); MOCK_METHOD1(Init, void(const TransitionState &parent));
MOCK_CONST_METHOD0(Clone, std::unique_ptr<TransitionState>()); MOCK_CONST_METHOD0(Clone, std::unique_ptr<TransitionState>());
MOCK_CONST_METHOD0(ParentBeamIndex, const int()); MOCK_CONST_METHOD0(ParentBeamIndex, int());
MOCK_METHOD1(SetBeamIndex, void(const int index)); MOCK_METHOD1(SetBeamIndex, void(int index));
MOCK_CONST_METHOD0(GetBeamIndex, const int()); MOCK_CONST_METHOD0(GetBeamIndex, int());
MOCK_CONST_METHOD0(GetScore, const float()); MOCK_CONST_METHOD0(GetScore, float());
MOCK_METHOD1(SetScore, void(const float score)); MOCK_METHOD1(SetScore, void(float score));
MOCK_CONST_METHOD0(IsGold, bool());
MOCK_METHOD1(SetGold, void(bool is_gold));
MOCK_CONST_METHOD0(HTMLRepresentation, string()); MOCK_CONST_METHOD0(HTMLRepresentation, string());
}; };
} // namespace dragnn } // namespace dragnn
} // namespace syntaxnet } // namespace syntaxnet
#endif // NLP_SAFT_OPENSOURCE_DRAGNN_CORE_TEST_MOCK_TRANSITION_STATE_H_ #endif // DRAGNN_CORE_TEST_MOCK_TRANSITION_STATE_H_
...@@ -11,10 +11,6 @@ component { ...@@ -11,10 +11,6 @@ component {
key: "language" key: "language"
value: "en" value: "en"
} }
parameters {
key: "neurosis_feature_syntax_version"
value: "2"
}
parameters { parameters {
key: "parser_skip_deterministic" key: "parser_skip_deterministic"
value: "false" value: "false"
......
...@@ -13,8 +13,8 @@ ...@@ -13,8 +13,8 @@
// limitations under the License. // limitations under the License.
// ============================================================================= // =============================================================================
#ifndef NLP_SAFT_OPENSOURCE_DRAGNN_IO_SENTENCE_INPUT_BATCH_H_ #ifndef DRAGNN_IO_SENTENCE_INPUT_BATCH_H_
#define NLP_SAFT_OPENSOURCE_DRAGNN_IO_SENTENCE_INPUT_BATCH_H_ #define DRAGNN_IO_SENTENCE_INPUT_BATCH_H_
#include <string> #include <string>
#include <vector> #include <vector>
...@@ -35,6 +35,9 @@ class SentenceInputBatch : public InputBatch { ...@@ -35,6 +35,9 @@ class SentenceInputBatch : public InputBatch {
void SetData( void SetData(
const std::vector<string> &stringified_sentence_protos) override; const std::vector<string> &stringified_sentence_protos) override;
// Returns the size of the batch.
int GetSize() const override { return data_.size(); }
// Translates to a vector of stringified Sentence protos. // Translates to a vector of stringified Sentence protos.
const std::vector<string> GetSerializedData() const override; const std::vector<string> GetSerializedData() const override;
...@@ -49,4 +52,4 @@ class SentenceInputBatch : public InputBatch { ...@@ -49,4 +52,4 @@ class SentenceInputBatch : public InputBatch {
} // namespace dragnn } // namespace dragnn
} // namespace syntaxnet } // namespace syntaxnet
#endif // NLP_SAFT_OPENSOURCE_DRAGNN_IO_SENTENCE_INPUT_BATCH_H_ #endif // DRAGNN_IO_SENTENCE_INPUT_BATCH_H_
...@@ -49,6 +49,9 @@ TEST(SentenceInputBatchTest, ConvertsFromStringifiedProtos) { ...@@ -49,6 +49,9 @@ TEST(SentenceInputBatchTest, ConvertsFromStringifiedProtos) {
EXPECT_NE(converted_data->at(i).workspace(), nullptr); EXPECT_NE(converted_data->at(i).workspace(), nullptr);
} }
// Check the batch size.
EXPECT_EQ(strings.size(), set.GetSize());
// Get the data back out. The strings should be identical. // Get the data back out. The strings should be identical.
auto output = set.GetSerializedData(); auto output = set.GetSerializedData();
EXPECT_EQ(output.size(), strings.size()); EXPECT_EQ(output.size(), strings.size());
......
...@@ -13,8 +13,8 @@ ...@@ -13,8 +13,8 @@
// limitations under the License. // limitations under the License.
// ============================================================================= // =============================================================================
#ifndef NLP_SAFT_OPENSOURCE_DRAGNN_IO_SYNTAXNET_SENTENCE_H_ #ifndef DRAGNN_IO_SYNTAXNET_SENTENCE_H_
#define NLP_SAFT_OPENSOURCE_DRAGNN_IO_SYNTAXNET_SENTENCE_H_ #define DRAGNN_IO_SYNTAXNET_SENTENCE_H_
#include "syntaxnet/sentence.pb.h" #include "syntaxnet/sentence.pb.h"
#include "syntaxnet/workspace.h" #include "syntaxnet/workspace.h"
...@@ -39,4 +39,4 @@ class SyntaxNetSentence { ...@@ -39,4 +39,4 @@ class SyntaxNetSentence {
} // namespace dragnn } // namespace dragnn
} // namespace syntaxnet } // namespace syntaxnet
#endif // NLP_SAFT_OPENSOURCE_DRAGNN_IO_SYNTAXNET_SENTENCE_H_ #endif // DRAGNN_IO_SYNTAXNET_SENTENCE_H_
...@@ -26,6 +26,12 @@ tf_proto_library( ...@@ -26,6 +26,12 @@ tf_proto_library(
srcs = ["spec.proto"], srcs = ["spec.proto"],
) )
# Proto library for the runtime performance-tuning extensions. Depends on
# :spec_proto because runtime.proto imports spec.proto and extends
# MasterSpec / ComponentSpec.
tf_proto_library(
    name = "runtime_proto",
    srcs = ["runtime.proto"],
    deps = [":spec_proto"],
)
tf_proto_library_py( tf_proto_library_py(
name = "data_py_pb2", name = "data_py_pb2",
srcs = ["data.proto"], srcs = ["data.proto"],
......
syntax = "proto2";
import "dragnn/protos/spec.proto";
package syntaxnet.dragnn.runtime;
// Performance tuning settings that only affect resource usage, not annotated
// output or correctness. This should be attached to the MasterSpec used to
// initialize a Master.
//
// NEXT ID: 2
message MasterPerformanceSettings {
  // Attaches these settings to a MasterSpec via a proto2 extension.
  extend MasterSpec {
    optional MasterPerformanceSettings master_spec_extension = 160848628;
  }

  // Maximum size of the free list in the SessionStatePool. NB: The default
  // value may occasionally change.
  optional uint64 session_state_pool_max_free_states = 1 [default = 4];
}
// As above, but for component-specific performance tuning settings.
//
// NEXT ID: 2
message ComponentPerformanceSettings {
  // Attaches these settings to a ComponentSpec via a proto2 extension.
  extend ComponentSpec {
    optional ComponentPerformanceSettings component_spec_extension = 160999422;
  }

  // Number of steps to pre-allocate for the relevant component. NB: The
  // default value may occasionally change.
  optional uint32 pre_allocate_num_steps = 1 [default = 50];
}
// Specification of an ArrayVariableStore.
//
// NEXT ID: 5
message ArrayVariableStoreSpec {
  // Characteristics of the variable data. The binary that loads the variables
  // must match these characteristics.
  optional uint32 version = 1;  // required version of the byte array format
  optional uint32 alignment_bytes = 2;  // required alignment of the byte array
  optional bool is_little_endian = 3;  // required endian-ness of the byte array

  // Variable specifications, in order of appearance in the byte array.
  repeated VariableSpec variable = 4;
}
// Specification of a single serialized variable.
//
// NEXT ID: 6
message VariableSpec {
  // Formats for serialized pre-trained variables. See VariableStore::Lookup()
  // for descriptions of the enumerators.
  enum Format {
    FORMAT_UNKNOWN = 0;
    FORMAT_FLAT = 1;
    FORMAT_ROW_MAJOR_MATRIX = 2;
    FORMAT_COLUMN_BLOCKED_ROW_MAJOR_MATRIX = 3;
  }

  // Name of the variable.
  optional string name = 1;

  // Format of the variable.
  optional Format format = 2 [default = FORMAT_UNKNOWN];

  // Dimensions of the variable. The semantics depend on the format, but are
  // always in logical units (number of floats, etc.) rather than bytes:
  //
  //   * flat: single value with the length of the vector
  //   * row-major matrix: two values, [rows, columns]
  //   * column-blocked row-major matrix: three values,
  //     [rows, columns, block_size]
  //     (NOTE(review): the original comment said "row-blocked column-major",
  //     which contradicts FORMAT_COLUMN_BLOCKED_ROW_MAJOR_MATRIX above;
  //     confirm the exact dimension semantics in VariableStore::Lookup().)
  repeated uint32 dimension = 5;

  // Number of sub-views in the AlignedArea that contained the variable.
  optional uint64 num_views = 3;

  // Sub-view size in bytes for the AlignedArea that contained the variable.
  optional uint64 view_size = 4;
}
...@@ -16,6 +16,7 @@ message MasterSpec { ...@@ -16,6 +16,7 @@ message MasterSpec {
// Whether to extract debug traces. // Whether to extract debug traces.
optional bool debug_tracing = 4 [default = false]; optional bool debug_tracing = 4 [default = false];
extensions 1000 to max;
reserved 2, 3, 5; reserved 2, 3, 5;
} }
...@@ -28,8 +29,7 @@ message ComponentSpec { ...@@ -28,8 +29,7 @@ message ComponentSpec {
// TransitionSystem to use. // TransitionSystem to use.
optional RegisteredModuleSpec transition_system = 2; optional RegisteredModuleSpec transition_system = 2;
// Resources that this component depends on. These are copied to TaskInputs // Resources that this component depends on.
// when calling SAFT code.
repeated Resource resource = 3; repeated Resource resource = 3;
// Feature space configurations. // Feature space configurations.
...@@ -58,6 +58,8 @@ message ComponentSpec { ...@@ -58,6 +58,8 @@ message ComponentSpec {
// Default max number of active states for beam inference. // Default max number of active states for beam inference.
optional int32 inference_beam_size = 12 [default = 1]; optional int32 inference_beam_size = 12 [default = 1];
extensions 1000 to max;
} }
// Super generic container for any registered sub-piece of DRAGNN. // Super generic container for any registered sub-piece of DRAGNN.
...@@ -65,14 +67,11 @@ message RegisteredModuleSpec { ...@@ -65,14 +67,11 @@ message RegisteredModuleSpec {
// Name of the registered class. // Name of the registered class.
optional string registered_name = 1; optional string registered_name = 1;
// Parameters to set while initializing this system; these are copied to // Parameters to set while initializing this system.
// Parameters in a TaskSpec when calling SAFT code, or via kwargs in TF Python
// code.
map<string, string> parameters = 2; map<string, string> parameters = 2;
} }
// Fixed resources that will be converted into TaskInput's when calling SAFT // Fixed resource.
// code.
message Resource { message Resource {
optional string name = 1; optional string name = 1;
repeated Part part = 2; repeated Part part = 2;
...@@ -218,6 +217,9 @@ message GridPoint { ...@@ -218,6 +217,9 @@ message GridPoint {
optional double gradient_clip_norm = 11 [default = 0.0]; optional double gradient_clip_norm = 11 [default = 0.0];
// A spec for using multiple optimization methods. // A spec for using multiple optimization methods.
//
// This is not guaranteed to work for recursively-defined composite
// optimizers.
message CompositeOptimizerSpec { message CompositeOptimizerSpec {
// First optimizer. // First optimizer.
optional GridPoint method1 = 1; optional GridPoint method1 = 1;
...@@ -227,6 +229,11 @@ message GridPoint { ...@@ -227,6 +229,11 @@ message GridPoint {
// After this number of steps, switch from first to second. // After this number of steps, switch from first to second.
optional int32 switch_after_steps = 3; optional int32 switch_after_steps = 3;
// Whether to reset the learning rate (which normally decays) after
// switching optimizers. Limitations: It will only reset to the initial
// learning rate, and won't work for recursively-defined optimizers.
optional bool reset_learning_rate = 4 [default = false];
} }
optional CompositeOptimizerSpec composite_optimizer_spec = 12; optional CompositeOptimizerSpec composite_optimizer_spec = 12;
...@@ -247,6 +254,7 @@ message GridPoint { ...@@ -247,6 +254,7 @@ message GridPoint {
// place. Typically a single component. // place. Typically a single component.
optional string self_norm_components_filter = 21; optional string self_norm_components_filter = 21;
extensions 1000 to max;
reserved 5, 6; reserved 5, 6;
} }
......
...@@ -16,6 +16,11 @@ cc_binary( ...@@ -16,6 +16,11 @@ cc_binary(
], ],
) )
# Bundles every file under testdata/ so tests (e.g. the model saver test)
# can list it in their data attribute.
filegroup(
    name = "testdata",
    data = glob(["testdata/**"]),
)
py_library( py_library(
name = "load_dragnn_cc_impl_py", name = "load_dragnn_cc_impl_py",
srcs = ["load_dragnn_cc_impl.py"], srcs = ["load_dragnn_cc_impl.py"],
...@@ -64,7 +69,51 @@ py_library( ...@@ -64,7 +69,51 @@ py_library(
py_library( py_library(
name = "dragnn_ops", name = "dragnn_ops",
srcs = ["dragnn_ops.py"], srcs = ["dragnn_ops.py"],
deps = [], deps = [
":load_dragnn_cc_impl_py",
"//dragnn/core:dragnn_bulk_ops",
"//dragnn/core:dragnn_ops",
"//syntaxnet:load_parser_ops_py",
],
)
# Library backing the dragnn_model_saver binary and its test; builds and
# exports DRAGNN graphs using graph_builder and network_units.
py_library(
    name = "dragnn_model_saver_lib",
    srcs = ["dragnn_model_saver_lib.py"],
    deps = [
        ":dragnn_ops",
        ":graph_builder",
        ":load_dragnn_cc_impl_py",
        ":network_units",
        "//dragnn/protos:spec_py_pb2",
        "//syntaxnet:load_parser_ops_py",
        "//syntaxnet:sentence_py_pb2",
        "@org_tensorflow//tensorflow:tensorflow_py",
        "@org_tensorflow//tensorflow/core:protos_all_py",
    ],
)
# Unit test for dragnn_model_saver_lib; reads fixtures from the :testdata
# filegroup.
py_test(
    name = "dragnn_model_saver_lib_test",
    srcs = ["dragnn_model_saver_lib_test.py"],
    data = [":testdata"],
    deps = [
        ":dragnn_model_saver_lib",
        "//dragnn/protos:spec_py_pb2",
        "@org_tensorflow//tensorflow:tensorflow_py",
    ],
)
# Command-line entry point wrapping dragnn_model_saver_lib.
py_binary(
    name = "dragnn_model_saver",
    srcs = ["dragnn_model_saver.py"],
    deps = [
        ":dragnn_model_saver_lib",
        ":spec_builder",
        "//dragnn/protos:spec_py_pb2",
        "@org_tensorflow//tensorflow:tensorflow_py",
        "@org_tensorflow//tensorflow/core:protos_all_py",
    ],
) )
py_library( py_library(
...@@ -76,6 +125,7 @@ py_library( ...@@ -76,6 +125,7 @@ py_library(
":composite_optimizer", ":composite_optimizer",
":dragnn_ops", ":dragnn_ops",
":network_units", ":network_units",
":transformer_units",
":wrapped_units", ":wrapped_units",
"//dragnn/protos:spec_py_pb2", "//dragnn/protos:spec_py_pb2",
"//syntaxnet/util:check", "//syntaxnet/util:check",
...@@ -184,10 +234,7 @@ py_test( ...@@ -184,10 +234,7 @@ py_test(
":bulk_component", ":bulk_component",
":components", ":components",
":dragnn_ops", ":dragnn_ops",
":load_dragnn_cc_impl_py",
":network_units", ":network_units",
"//dragnn/core:dragnn_bulk_ops",
"//dragnn/core:dragnn_ops",
"//dragnn/protos:spec_py_pb2", "//dragnn/protos:spec_py_pb2",
"//syntaxnet:load_parser_ops_py", "//syntaxnet:load_parser_ops_py",
"//syntaxnet:sentence_py_pb2", "//syntaxnet:sentence_py_pb2",
...@@ -201,7 +248,6 @@ py_test( ...@@ -201,7 +248,6 @@ py_test(
srcs = ["composite_optimizer_test.py"], srcs = ["composite_optimizer_test.py"],
deps = [ deps = [
":composite_optimizer", ":composite_optimizer",
":load_dragnn_cc_impl_py",
"//dragnn/core:dragnn_bulk_ops", "//dragnn/core:dragnn_bulk_ops",
"//dragnn/core:dragnn_ops", "//dragnn/core:dragnn_ops",
"//syntaxnet:load_parser_ops_py", "//syntaxnet:load_parser_ops_py",
...@@ -217,15 +263,13 @@ py_test( ...@@ -217,15 +263,13 @@ py_test(
data = [ data = [
"//dragnn/core:testdata", "//dragnn/core:testdata",
], ],
shard_count = 5,
tags = [ tags = [
"notsan", "notsan",
], ],
deps = [ deps = [
":dragnn_ops", ":dragnn_ops",
":graph_builder", ":graph_builder",
":load_dragnn_cc_impl_py",
"//dragnn/core:dragnn_bulk_ops",
"//dragnn/core:dragnn_ops",
"//dragnn/protos:spec_py_pb2", "//dragnn/protos:spec_py_pb2",
"//dragnn/protos:trace_py_pb2", "//dragnn/protos:trace_py_pb2",
"//syntaxnet:load_parser_ops_py", "//syntaxnet:load_parser_ops_py",
...@@ -240,7 +284,6 @@ py_test( ...@@ -240,7 +284,6 @@ py_test(
size = "small", size = "small",
srcs = ["network_units_test.py"], srcs = ["network_units_test.py"],
deps = [ deps = [
":load_dragnn_cc_impl_py",
":network_units", ":network_units",
"//dragnn/core:dragnn_bulk_ops", "//dragnn/core:dragnn_bulk_ops",
"//dragnn/core:dragnn_ops", "//dragnn/core:dragnn_ops",
...@@ -256,6 +299,7 @@ py_test( ...@@ -256,6 +299,7 @@ py_test(
srcs = ["sentence_io_test.py"], srcs = ["sentence_io_test.py"],
data = ["//syntaxnet:testdata"], data = ["//syntaxnet:testdata"],
deps = [ deps = [
":dragnn_ops",
":sentence_io", ":sentence_io",
"//syntaxnet:load_parser_ops_py", "//syntaxnet:load_parser_ops_py",
"//syntaxnet:parser_ops", "//syntaxnet:parser_ops",
...@@ -373,3 +417,30 @@ py_library( ...@@ -373,3 +417,30 @@ py_library(
"@org_tensorflow//tensorflow:tensorflow_py", "@org_tensorflow//tensorflow:tensorflow_py",
], ],
) )
# Transformer network units; layered on top of the base network_units library.
py_library(
    name = "transformer_units",
    srcs = ["transformer_units.py"],
    deps = [
        ":network_units",
        "//syntaxnet/util:check",
        "@org_tensorflow//tensorflow:tensorflow_py",
    ],
)
# Unit test for transformer_units.
py_test(
    name = "transformer_units_test",
    size = "small",
    srcs = ["transformer_units_test.py"],
    deps = [
        ":network_units",
        ":transformer_units",
        "//dragnn/core:dragnn_bulk_ops",
        "//dragnn/core:dragnn_ops",
        "//dragnn/protos:spec_py_pb2",
        "//syntaxnet:load_parser_ops_py",
        "@org_tensorflow//tensorflow:tensorflow_py",
        "@org_tensorflow//tensorflow/core:protos_all_py",
    ],
)
...@@ -95,7 +95,7 @@ class BiaffineDigraphNetwork(network_units.NetworkUnitInterface): ...@@ -95,7 +95,7 @@ class BiaffineDigraphNetwork(network_units.NetworkUnitInterface):
self._regularized_weights.extend(self._weights) self._regularized_weights.extend(self._weights)
# Negative Layer.dim indicates that the dimension is dynamic. # Negative Layer.dim indicates that the dimension is dynamic.
self._layers.append(network_units.Layer(self, 'adjacency', -1)) self._layers.append(network_units.Layer(component, 'adjacency', -1))
def create(self, def create(self,
fixed_embeddings, fixed_embeddings,
...@@ -209,7 +209,8 @@ class BiaffineLabelNetwork(network_units.NetworkUnitInterface): ...@@ -209,7 +209,8 @@ class BiaffineLabelNetwork(network_units.NetworkUnitInterface):
self._params.extend(self._weights + self._biases) self._params.extend(self._weights + self._biases)
self._regularized_weights.extend(self._weights) self._regularized_weights.extend(self._weights)
self._layers.append(network_units.Layer(self, 'labels', self._num_labels)) self._layers.append(
network_units.Layer(component, 'labels', self._num_labels))
def create(self, def create(self,
fixed_embeddings, fixed_embeddings,
......
...@@ -216,9 +216,11 @@ def build_cross_entropy_loss(logits, gold): ...@@ -216,9 +216,11 @@ def build_cross_entropy_loss(logits, gold):
logits = tf.gather(logits, valid) logits = tf.gather(logits, valid)
correct = tf.reduce_sum(tf.to_int32(tf.nn.in_top_k(logits, gold, 1))) correct = tf.reduce_sum(tf.to_int32(tf.nn.in_top_k(logits, gold, 1)))
total = tf.size(gold) total = tf.size(gold)
cost = tf.reduce_sum( with tf.control_dependencies([tf.assert_positive(total)]):
tf.contrib.nn.deprecated_flipped_sparse_softmax_cross_entropy_with_logits( cost = tf.reduce_sum(
logits, tf.cast(gold, tf.int64))) / tf.cast(total, tf.float32) tf.nn.sparse_softmax_cross_entropy_with_logits(
labels=tf.cast(gold, tf.int64), logits=logits)) / tf.cast(
total, tf.float32)
return cost, correct, total return cost, correct, total
...@@ -267,6 +269,22 @@ class BulkFeatureExtractorComponentBuilder(component.ComponentBuilderBase): ...@@ -267,6 +269,22 @@ class BulkFeatureExtractorComponentBuilder(component.ComponentBuilderBase):
correct, total = tf.constant(0), tf.constant(0) correct, total = tf.constant(0), tf.constant(0)
return state.handle, cost, correct, total return state.handle, cost, correct, total
def build_post_restore_hook(self):
  """Builds a graph that should be executed after the restore op.

  This graph is intended to be run once, before the inference pipeline is
  run.

  Returns:
    setup_op - An op that, when run, guarantees all setup ops will run.
  """
  logging.info('Building restore hook for component: %s', self.spec.name)
  with tf.variable_scope(self.name):
    # Delegate to the network's hook builder when it provides one; otherwise
    # there is nothing to set up for this component.
    network_hook_builder = getattr(self.network, 'build_post_restore_hook',
                                   None)
    if callable(network_hook_builder):
      return [network_hook_builder()]
    return []
def build_greedy_inference(self, state, network_states, def build_greedy_inference(self, state, network_states,
during_training=False): during_training=False):
"""Extracts features and advances a batch using the oracle path. """Extracts features and advances a batch using the oracle path.
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment