"test/srt/test_torch_compile_moe.py" did not exist on "86fc0d79d0b564fba1c313feafd15323ba731418"
Commit 4364390a authored by Ivan Bogatyy's avatar Ivan Bogatyy Committed by calberti
Browse files

Release DRAGNN bulk networks (#2785)

* Release DRAGNN bulk networks
parent 638fd759
......@@ -80,6 +80,38 @@ pad_to_batch: If set, the op will pad/truncate to this number of elements.
pad_to_steps: If set, the op will pad/truncate to this number of steps.
)doc");
// Registration for the bulk embedding-lookup op. It is stateful because it
// advances the ComputeSession identified by the input handle; see the Doc()
// text below for the full contract.
REGISTER_OP("BulkEmbedFixedFeatures")
    .Input("handle: string")
    .Input("embedding_matrix: num_channels * float")
    .Output("output_handle: string")
    .Output("embedding_vectors: float")
    .Output("num_steps: int32")
    .Attr("component: string")
    .Attr("num_channels: int")
    .Attr("pad_to_batch: int")
    .Attr("pad_to_steps: int")
    .SetIsStateful()
    .Doc(R"doc(
This op is a more efficient version of BulkFixedFeatures.
It is intended to be run with large batch sizes at inference time. The op takes
a handle to ComputeSession and embedding matrices as tensor inputs, and directly
outputs concatenated embedding vectors. It calls the BulkEmbedFixedFeatures
method on the underlying component directly, so it requires a padding vector
to be passed.
handle: A handle to ComputeSession.
embedding_matrix: Embedding matrices.
output_handle: A handle to the same ComputeSession after advancement.
embedding_vectors: (matrix of float) Concatenated embeddings,
shaped as (batch * beam * token) x sum_channel(embedding_dim[channel]).
num_steps: The batch was unrolled for these many steps.
component: The name of a Component instance, matching the ComponentSpec.name.
num_channels: The number of FixedFeature channels.
pad_to_batch: The op will pad/truncate to this number of elements.
pad_to_steps: The op will pad/truncate to this number of steps.
)doc");
REGISTER_OP("BulkAdvanceFromOracle")
.Input("handle: string")
.Output("output_handle: string")
......
......@@ -30,6 +30,7 @@
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/core/threadpool.h"
#include "tensorflow/core/lib/io/path.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/mutex.h"
......@@ -40,6 +41,8 @@ using tensorflow::DT_INT32;
using tensorflow::DT_INT64;
using tensorflow::DT_STRING;
using tensorflow::DataType;
using tensorflow::io::Dirname;
using tensorflow::io::JoinPath;
using tensorflow::OpKernel;
using tensorflow::OpKernelConstruction;
using tensorflow::OpKernelContext;
......@@ -53,6 +56,59 @@ namespace dragnn {
typedef ResourceContainer<ComputeSession> ComputeSessionResource;
typedef ResourceContainer<ComputeSessionPool> ComputeSessionPoolResource;
typedef ResourceContainer<string> StringResource;
namespace {
const char kGlobalContainer[] = "__reserved_global_container";
const char kBasePathTag[] = "__reserved_asset_base_path";
const char kUnmanagedAssetDirectory[] = "assets.extra";
// When restoring a graph from a SavedModel, this op computes the new asset
// base directory and stores it as a StringResource in the ResourceMgr (under
// a reserved container/tag). The GetSession op looks this resource up and
// uses it to rewrite the MasterSpec before acquiring a ComputeSession.
class SetAssetDirectory : public OpKernel {
 public:
  explicit SetAssetDirectory(OpKernelConstruction *context)
      : OpKernel(context) {
    // One string input (the asset path); one string output (the path echoed
    // back, see Compute()).
    OP_REQUIRES_OK(context, context->MatchSignature({DT_STRING}, {DT_STRING}));
  }

  void Compute(OpKernelContext *context) override {
    ResourceMgr *rmgr = context->resource_manager();
    const string asset_path = context->input(0).scalar<string>()();

    // TODO(googleuser): Get this data in a way that isn't fragile as all hell.
    // "I've done stuff I ain't proud of... and the stuff I am proud of is
    // disgusting." -- Moe
    // Strip the last two path components from the input (e.g.
    // ".../assets/master_spec" -> "...") and append the unmanaged asset
    // directory name.
    auto extra_asset_dir =
        JoinPath(Dirname(Dirname(asset_path)), kUnmanagedAssetDirectory);
    // Fixed: the original message had no space after the colon, which made
    // the logged path run together with the label.
    LOG(INFO) << "Found extra assets path at: " << extra_asset_dir;

    // Rather than attempt to rewrite the MasterSpec here, we save off a
    // StringResource containing the new asset path. It will be used in
    // the GetSession op, if it exists.
    std::unique_ptr<string> asset_path_ptr(new string(extra_asset_dir));
    OP_REQUIRES_OK(context, rmgr->Create<StringResource>(
                                kGlobalContainer, kBasePathTag,
                                new StringResource(std::move(asset_path_ptr))));

    // This isn't used anywhere - it just allows us to have an output so that
    // it's easier to reason about Tensorflow's graph execution.
    Tensor *output;
    OP_REQUIRES_OK(context,
                   context->allocate_output(0, TensorShape({1}), &output));
    output->vec<string>()(0) = asset_path;
  }

 private:
  TF_DISALLOW_COPY_AND_ASSIGN(SetAssetDirectory);
};

REGISTER_KERNEL_BUILDER(Name("SetAssetDirectory").Device(DEVICE_CPU),
                        SetAssetDirectory);
// Given a MasterSpec proto, outputs a handle to a ComputeSession.
class GetSession : public OpKernel {
......@@ -66,6 +122,7 @@ class GetSession : public OpKernel {
CHECK(master_spec_.ParseFromString(master_spec_str));
CHECK(grid_point_.ParseFromString(grid_point_spec_str));
OP_REQUIRES_OK(context, context->MatchSignature({DT_STRING}, {DT_STRING}));
has_overwritten_spec_ = false;
}
void Compute(OpKernelContext *context) override {
......@@ -74,10 +131,32 @@ class GetSession : public OpKernel {
// Create the pool for this container, or re-use one that was allocated in a
// previous call.
auto create_pool = [this,
auto create_pool = [this, &rmgr,
&container](ComputeSessionPoolResource **resource) {
LOG(INFO) << "Creating new ComputeSessionPool in container handle: "
<< container;
if (has_overwritten_spec_) {
// TODO(googleuser): Figure out a way to test this.
// If there's already an overwritten spec, use that.
LOG(INFO) << "Creating new ComputeSessionPool in container handle: "
<< container << " with previously overwritten master spec.";
} else {
// If not, try to find the resource base.
StringResource *resource_base;
auto resource_base_lookup = rmgr->Lookup<StringResource>(
kGlobalContainer, kBasePathTag, &resource_base);
if (resource_base_lookup.ok()) {
// If that exists, the spec must be rewritten.
string resource_base_path = *resource_base->get();
LOG(INFO) << "Creating new ComputeSessionPool in container handle: "
<< container << " using resource directory base "
<< resource_base_path;
RewriteMasterSpec(resource_base_path);
resource_base->Unref();
} else {
// If not, just use the spec as is.
LOG(INFO) << "Creating new ComputeSessionPool in container handle: "
<< container << " without editing master spec.";
}
}
std::unique_ptr<ComputeSessionPool> pool(
new ComputeSessionPool(master_spec_, grid_point_));
*resource = new ComputeSessionPoolResource(std::move(pool));
......@@ -120,6 +199,23 @@ class GetSession : public OpKernel {
}
private:
// Rewrites this op's saved MasterSpec, appending the new base directory.
void RewriteMasterSpec(const string &new_base) {
for (auto &component_spec : *master_spec_.mutable_component()) {
for (auto &resource_def : *component_spec.mutable_resource()) {
for (auto &part_def : *resource_def.mutable_part()) {
part_def.set_file_pattern(
JoinPath(new_base, part_def.file_pattern()));
VLOG(2) << "New path: " << part_def.file_pattern();
}
}
}
VLOG(3) << "Rewritten spec: " << master_spec_.DebugString();
has_overwritten_spec_ = true;
}
bool has_overwritten_spec_;
MasterSpec master_spec_;
GridPoint grid_point_;
......@@ -141,7 +237,6 @@ REGISTER_KERNEL_BUILDER(Name("GetSession").Device(DEVICE_CPU), GetSession);
class ReleaseSession : public OpKernel {
public:
explicit ReleaseSession(OpKernelConstruction *context) : OpKernel(context) {
string master_spec_str;
OP_REQUIRES_OK(context, context->MatchSignature({DT_STRING}, {}));
}
......@@ -188,6 +283,53 @@ class ReleaseSession : public OpKernel {
REGISTER_KERNEL_BUILDER(Name("ReleaseSession").Device(DEVICE_CPU),
ReleaseSession);
// Reports session-pool load statistics to the graph: the total number of
// ComputeSession objects ever created for a container, and how many of them
// are currently checked out of the ComputeSessionPool.
class GetSessionCounts : public OpKernel {
 public:
  explicit GetSessionCounts(OpKernelConstruction *context) : OpKernel(context) {
    OP_REQUIRES_OK(context, context->MatchSignature({DT_STRING}, {DT_INT64}));
  }

  void Compute(OpKernelContext *context) override {
    const string container = context->input(0).scalar<string>()();
    VLOG(1) << "Getting stats for container: " << container;

    // Allocate the two-element stats output up front so both the found and
    // not-found paths can write into it.
    Tensor *output;
    OP_REQUIRES_OK(context,
                   context->allocate_output(0, TensorShape({2}), &output));
    auto stats = output->vec<int64>();

    // Look up the pool for this container; if no pool exists yet, report
    // zero created and zero outstanding sessions.
    ResourceMgr *rmgr = context->resource_manager();
    ComputeSessionPoolResource *pool_resource;
    const auto lookup_status = rmgr->Lookup<ComputeSessionPoolResource>(
        container, "pool", &pool_resource);
    if (!lookup_status.ok()) {
      stats(0) = 0;
      stats(1) = 0;
      return;
    }

    auto *pool = pool_resource->get();
    CHECK(pool != nullptr);
    stats(0) = pool->num_unique_sessions();
    stats(1) = pool->num_outstanding_sessions();
    pool_resource->Unref();
  }

 private:
  TF_DISALLOW_COPY_AND_ASSIGN(GetSessionCounts);
};

REGISTER_KERNEL_BUILDER(Name("GetSessionCounts").Device(DEVICE_CPU),
                        GetSessionCounts);
/*******************************************************************************
* ComputeSessionOps below here.
******************************************************************************/
......@@ -233,9 +375,17 @@ class AdvanceFromPrediction : public ComputeSessionOp {
void ComputeWithState(OpKernelContext *context,
ComputeSession *session) override {
const Tensor &scores = context->input(1);
session->AdvanceFromPrediction(component_name(),
scores.tensor<float, 2>().data(),
scores.NumElements());
const int num_items = scores.shape().dim_size(0);
const int num_actions = scores.shape().dim_size(1);
bool success = session->AdvanceFromPrediction(
component_name(), scores.tensor<float, 2>().data(), num_items,
num_actions);
if (success) {
VLOG(2) << "Score: " << scores.tensor<float, 2>();
}
OP_REQUIRES(
context, success,
tensorflow::errors::Internal("Unable to advance from prediction."));
}
private:
......@@ -247,13 +397,12 @@ REGISTER_KERNEL_BUILDER(Name("AdvanceFromPrediction").Device(DEVICE_CPU),
// Given a handle to a ComputeSession and a channel index, outputs fixed
// features.
// Fixed features are returned as 3 vectors or equal length:
// Fixed features are returned as 3 vectors of equal length:
// - ids: specifies which rows should be looked up in the embedding
// matrix,
// - weights: specifies a scale for each embedding vector,
// - indices: sorted vector that assigns the same index to embedding
// vectors
// that should be summed together.
// vectors that should be summed together.
//
// For example if we have 3 features, for a given channel, we might have:
// feature a: (5, 1)
......@@ -300,7 +449,10 @@ class ExtractFixedFeatures : public ComputeSessionOp {
int num_features = session->GetInputFeatures(
component_name(), indices_allocator, ids_allocator, weights_allocator,
channel_id_);
VLOG(2) << "Extracted " << num_features;
VLOG(2) << "Extracted features (" << num_features << "): "
<< " ids=" << context->mutable_output(1)->vec<int64>()
<< " weights=" << context->mutable_output(2)->vec<float>()
<< " indices=" << context->mutable_output(0)->vec<int32>();
}
private:
......@@ -524,6 +676,7 @@ class AttachDataReader : public ComputeSessionOp {
auto input_data(context->input(1).vec<string>());
std::vector<string> data;
data.reserve(input_data.size());
for (int i = 0; i < input_data.size(); ++i) {
data.push_back(input_data(i));
}
......@@ -642,5 +795,6 @@ class GetComponentTrace : public ComputeSessionOp {
REGISTER_KERNEL_BUILDER(Name("GetComponentTrace").Device(DEVICE_CPU),
GetComponentTrace);
} // namespace
} // namespace dragnn
} // namespace syntaxnet
......@@ -17,6 +17,7 @@
#include <memory>
#include <vector>
#include "dragnn/core/component_registry.h"
#include "dragnn/core/compute_session.h"
#include "dragnn/core/compute_session_pool.h"
#include "dragnn/core/resource_container.h"
......@@ -66,6 +67,87 @@ using testing::Return;
typedef ResourceContainer<ComputeSession> ComputeSessionResource;
typedef ResourceContainer<ComputeSessionPool> ComputeSessionPoolResource;
typedef ResourceContainer<string> StringResource;
namespace {
const char kGlobalContainer[] = "__reserved_global_container";
const char kBasePathTag[] = "__reserved_asset_base_path";
const char kUnmanagedAssetDirectory[] = "assets.extra";
} // namespace
// A minimal concrete Component used to validate registered construction.
// Every override is a trivial stub: queries return fixed values, mutators
// are no-ops, and collection getters return empty containers.
class TestComponent : public Component {
 public:
  TestComponent() {}

  // Captures the component name from the spec; the only state this stub keeps.
  void InitializeComponent(const ComponentSpec &spec) override {
    name_ = spec.name();
  }
  void InitializeData(
      const std::vector<std::vector<const TransitionState *>> &states,
      int max_beam_size, InputBatchCache *input_data) override {}
  void InitializeTracing() override {}
  void DisableTracing() override {}
  bool IsReady() const override { return true; }
  string Name() const override { return name_; }
  int BeamSize() const override { return 3; }
  int BatchSize() const override { return 1; }
  int StepsTaken(int batch_index) const override { return 0; }
  int GetBeamIndexAtStep(int step, int current_index,
                         int batch) const override {
    return 0;
  }
  int GetSourceBeamIndex(int current_index, int batch) const override {
    return 0;
  }
  bool AdvanceFromPrediction(const float *score_matrix, int num_items,
                             int num_actions) override {
    return true;
  }
  void AdvanceFromOracle() override {}
  bool IsTerminal() const override { return true; }
  std::function<int(int, int, int)> GetStepLookupFunction(
      const string &method) override {
    return nullptr;
  }
  std::vector<std::vector<const TransitionState *>> GetBeam() override {
    return {};
  }
  int GetFixedFeatures(std::function<int32 *(int)> allocate_indices,
                       std::function<int64 *(int)> allocate_ids,
                       std::function<float *(int)> allocate_weights,
                       int channel_id) const override {
    return 0;
  }
  int BulkGetFixedFeatures(const BulkFeatureExtractor &extractor) override {
    return 0;
  }
  void BulkEmbedFixedFeatures(
      int batch_size_padding, int num_steps_padding, int output_array_size,
      const vector<const float *> &per_channel_embeddings,
      float *embedding_matrix) override {}
  std::vector<LinkFeatures> GetRawLinkFeatures(int channel_id) const override {
    return {};
  }
  std::vector<std::vector<int>> GetOracleLabels() const override {
    return {};
  }
  void FinalizeData() override {}
  void ResetComponent() override {}
  std::vector<std::vector<ComponentTrace>> GetTraceProtos() const override {
    return {};
  }
  void AddTranslatedLinkFeaturesToTrace(
      const std::vector<LinkFeatures> &features, int channel_id) override {}

  // Component name captured in InitializeComponent().
  string name_;
};

REGISTER_DRAGNN_COMPONENT(TestComponent);
class DragnnOpKernelsTest : public tensorflow::OpsTestBase {
public:
......@@ -106,6 +188,42 @@ LinkFeatures MakeFeatures(int batch_index, int beam_index, int step) {
return features;
}
// The SetAssetDirectory op should
// 1. When given an asset path (foo/bar/baz/asset/thing), strip the path to
//    foo/bar/baz and add 'assets.extra' to it.
// 2. Store that path in the resource manager.
TEST_F(DragnnOpKernelsTest, SetAssetDirectoryTest) {
  // Input path and the path the op is expected to derive from it: the last
  // two components are stripped and the unmanaged asset directory appended.
  const string new_asset_path = "new/directory/path/asset/master_spec";
  const string expected_asset_path =
      StrCat("new/directory/path/", kUnmanagedAssetDirectory);

  // Create and initialize the kernel under test.
  TF_ASSERT_OK(NodeDefBuilder("set_asset_directory", "SetAssetDirectory")
                   .Input(FakeInput(DT_STRING))  // The new asset path.
                   .Finalize(node_def()));
  TF_ASSERT_OK(InitOp());

  // Set the input data.
  AddInputFromList<string>(TensorShape({1}), {new_asset_path});

  // Reset the test context to ensure it's clean.
  ResetOpKernelContext();

  // Run the kernel.
  TF_EXPECT_OK(RunOpKernelWithContext());

  // Expect that the ResourceMgr contains the correct string.
  StringResource *resource;
  TF_EXPECT_OK(resource_mgr()->Lookup<StringResource>(kGlobalContainer,
                                                      kBasePathTag, &resource));
  EXPECT_EQ(*resource->get(), expected_asset_path);
  resource->Unref();
}
// The GetSessionOp should
// 1. create a ComputeSessionPool resource and store it in the ResourceMgr,
// 2. create a ComputeSession resource and store it in the ResourceMgr,
......@@ -164,6 +282,103 @@ TEST_F(DragnnOpKernelsTest, GetSessionOpTest) {
pool_resource->Unref();
}
// If an asset_base_path resource exists, the GetSession op should prepend
// that path to all paths in the MasterSpec before creating a session.
TEST_F(DragnnOpKernelsTest, GetSessionWithAssetBasePathTest) {
  // Build a MasterSpec with two components whose resource parts exercise
  // zero, one, and multiple parts per resource; every part's file pattern
  // should get the base path prepended.
  const string new_asset_path = "new/base";
  MasterSpec spec;

  // The first component in the MasterSpec has one resource with one part.
  auto component_one = spec.add_component();
  auto backend_one = component_one->mutable_backend();
  backend_one->set_registered_name("TestComponent");
  component_one->add_resource()->add_part()->set_file_pattern(
      "path/to/an/asset.txt");
  const string expected_component_one_asset = "new/base/path/to/an/asset.txt";

  auto component_two = spec.add_component();
  auto backend_two = component_two->mutable_backend();
  backend_two->set_registered_name("TestComponent");

  // The second component's first resource has no assets.
  component_two->add_resource();

  // The second component's second resource has one part.
  vector<string> expected_component_two_assets;
  component_two->add_resource()->add_part()->set_file_pattern(
      "another/dir/with/an/asset.txt");
  expected_component_two_assets.push_back(
      "new/base/another/dir/with/an/asset.txt");

  // The second component's third resource has two parts.
  auto third_resource = component_two->add_resource();
  third_resource->add_part()->set_file_pattern(
      "another/dir/with/an/asset3.jif");
  expected_component_two_assets.push_back(
      "new/base/another/dir/with/an/asset3.jif");
  third_resource->add_part()->set_file_pattern(
      "another/dir/with/an/asset4.jif");
  expected_component_two_assets.push_back(
      "new/base/another/dir/with/an/asset4.jif");

  LOG(INFO) << spec.DebugString();

  // Serialize the spec and (empty) hyperparameters for the op attrs.
  string master_spec_str;
  spec.SerializeToString(&master_spec_str);

  GridPoint hyperparams;
  string hyperparams_str;
  hyperparams.SerializeToString(&hyperparams_str);

  // Create and initialize the kernel under test.
  TF_ASSERT_OK(
      NodeDefBuilder("get_session", "GetSession")
          .Attr("master_spec", master_spec_str)
          .Attr("grid_point", hyperparams_str)
          .Input(FakeInput(DT_STRING))  // The handle for the ComputeSession.
          .Finalize(node_def()));
  TF_ASSERT_OK(InitOp());

  // Set the input data.
  const string container_string = "container_str";
  AddInputFromList<string>(TensorShape({1}), {container_string});

  // Reset the test context to ensure it's clean.
  ResetOpKernelContext();

  // Create the string in the resource manager. GetSession should find this
  // reserved resource and rewrite the spec with it.
  std::unique_ptr<string> asset_path_ptr(new string(new_asset_path));
  TF_EXPECT_OK(resource_mgr()->Create<StringResource>(
      kGlobalContainer, kBasePathTag,
      new StringResource(std::move(asset_path_ptr))));

  // Run the kernel.
  TF_EXPECT_OK(RunOpKernelWithContext());

  // Expect that the ResourceMgr contains a ComputeSessionPoolResource.
  const string pool_id_str = "pool";
  ComputeSessionPoolResource *pool_resource;
  TF_EXPECT_OK(resource_mgr()->Lookup<ComputeSessionPoolResource>(
      container_string, pool_id_str, &pool_resource));

  // Validate that the master spec held by the pool has the new directory names.
  auto rewritten_spec = pool_resource->get()->GetSpec();
  EXPECT_EQ(rewritten_spec.component(0).resource(0).part(0).file_pattern(),
            expected_component_one_asset);
  EXPECT_EQ(rewritten_spec.component(1).resource(1).part(0).file_pattern(),
            expected_component_two_assets.at(0));
  EXPECT_EQ(rewritten_spec.component(1).resource(2).part(0).file_pattern(),
            expected_component_two_assets.at(1));
  EXPECT_EQ(rewritten_spec.component(1).resource(2).part(1).file_pattern(),
            expected_component_two_assets.at(2));

  // Unref the managed resources so they get destroyed properly.
  pool_resource->Unref();
}
// The GetSessionOp should take a session stored in the resource manager
// and return it to the ComputeSessionPool.
TEST_F(DragnnOpKernelsTest, ReleaseSessionOpTest) {
......@@ -217,6 +432,56 @@ TEST_F(DragnnOpKernelsTest, ReleaseSessionOpTest) {
EXPECT_EQ(null_resource, nullptr);
}
// The GetSessionCounts op should report the number of sessions created and
// free.
TEST_F(DragnnOpKernelsTest, GetSessionCountsOpTest) {
  // Create and initialize the kernel under test.
  TF_ASSERT_OK(
      NodeDefBuilder("get_session_counts", "GetSessionCounts")
          .Input(FakeInput(DT_STRING))  // The handle for the ComputeSession.
          .Finalize(node_def()));
  TF_ASSERT_OK(InitOp());

  // Set the input data.
  const string container_string = "container_str";
  AddInputFromList<string>(TensorShape({1}), {container_string});

  // Reset the test context to ensure it's clean.
  ResetOpKernelContext();

  // Create a ComputeSessionPool.
  MasterSpec spec;
  GridPoint hyperparams;
  std::unique_ptr<ComputeSessionPool> pool(
      new ComputeSessionPool(spec, hyperparams));

  // Get an unowned pointer to the ComputeSessionPool before moving
  // the pool to the resource manager.
  ComputeSessionPool *pool_ptr = pool.get();
  TF_ASSERT_OK(resource_mgr()->Create<ComputeSessionPoolResource>(
      container_string, "pool",
      new ComputeSessionPoolResource(std::move(pool))));

  // Create two ComputeSessions.
  auto session_one = pool_ptr->GetSession();
  auto session_two = pool_ptr->GetSession();

  // Return one of them.
  pool_ptr->ReturnSession(std::move(session_two));

  // At this point, the pool should report that it has one outstanding session
  // and two sessions total.
  EXPECT_EQ(1, pool_ptr->num_outstanding_sessions());
  EXPECT_EQ(2, pool_ptr->num_unique_sessions());

  // Run the kernel.
  TF_EXPECT_OK(RunOpKernelWithContext());

  // Output element [0] is total created sessions; [1] is outstanding sessions.
  EXPECT_EQ(pool_ptr->num_unique_sessions(), GetOutput(0)->vec<int64>()(0));
  EXPECT_EQ(pool_ptr->num_outstanding_sessions(),
            GetOutput(0)->vec<int64>()(1));
}
// The AdvanceFromOracle op should call AdvanceFromOracle on the specified
// component name.
TEST_F(DragnnOpKernelsTest, AdvanceFromOracleOpTest) {
......@@ -287,14 +552,65 @@ TEST_F(DragnnOpKernelsTest, AdvanceFromPredictionOpTest) {
// Set expectations on the mock session.
auto validator_function = [weights](const string &component_name,
const float score_matrix[],
int score_matrix_length) {
EXPECT_EQ(weights.size(), score_matrix_length);
const float *score_matrix, int num_items,
int num_actions) {
EXPECT_EQ(weights.size(), num_items * num_actions);
for (int i = 0; i < weights.size(); ++i) {
EXPECT_EQ(weights[i], score_matrix[i]);
}
return true;
};
EXPECT_CALL(*mock_session_ptr, AdvanceFromPrediction(component_name, _, _, _))
.WillOnce(Invoke(validator_function));
// Run the kernel.
TF_EXPECT_OK(RunOpKernelWithContext());
}
// The AdvanceFromPredicton op should call AdvanceFromPrediction on the
// specified component with the passed scores. If it returns false, the op
// should not return OK.
TEST_F(DragnnOpKernelsTest, AdvanceFromPredictionFailureTest) {
// Create and initialize the kernel under test.
const string component_name = "TESTING_COMPONENT_NAME";
TF_ASSERT_OK(
NodeDefBuilder("advance_from_prediction", "AdvanceFromPrediction")
.Attr("component", component_name)
.Input(FakeInput(DT_STRING)) // The handle for the ComputeSession.
.Input(FakeInput(DT_FLOAT)) // The prediction tensor.
.Finalize(node_def()));
TF_ASSERT_OK(InitOp());
// Set the input data.
const string container_string = "container_str";
const string id_string = "id_str";
AddInputFromList<string>(TensorShape({2}), {container_string, id_string});
const std::vector<float> weights = {1.1, 2.2, 3.3, 4.4};
AddInputFromArray<float>(TensorShape({2, 2}), weights);
// Reset the test context to ensure it's clean.
ResetOpKernelContext();
// Create a MockComputeSession and set expectations.
std::unique_ptr<MockComputeSession> mock_session(new MockComputeSession());
MockComputeSession *mock_session_ptr = mock_session.get();
// Wrap the ComputeSessionResource and put it into the resource manager.
TF_ASSERT_OK(resource_mgr()->Create<ComputeSessionResource>(
container_string, id_string,
new ComputeSessionResource(std::move(mock_session))));
// Set expectations on the mock session.
auto validator_function = [weights](const string &component_name,
const float *score_matrix, int num_items,
int num_actions) {
EXPECT_EQ(weights.size(), num_items * num_actions);
for (int i = 0; i < weights.size(); ++i) {
EXPECT_EQ(weights[i], score_matrix[i]);
}
return true;
};
EXPECT_CALL(*mock_session_ptr, AdvanceFromPrediction(component_name, _, _))
EXPECT_CALL(*mock_session_ptr, AdvanceFromPrediction(component_name, _, _, _))
.WillOnce(Invoke(validator_function));
// Run the kernel.
......
......@@ -18,6 +18,20 @@
namespace syntaxnet {
namespace dragnn {
// Registration for SetAssetDirectory. Stateful because the kernel stores the
// overridden asset base path in the ResourceMgr as a side effect.
REGISTER_OP("SetAssetDirectory")
    .Input("asset_directory: string")
    .Output("asset_directory_out: string")
    .SetIsStateful()
    .Doc(R"doc(
Override the paths to assets specified in the MasterSpec with the given
asset_directory. This op must be called before any calls to GetSession, as it
will create a new session pool with the overridden master spec.
asset_directory: The directory containing all the assets. Note that all assets
must be in a single flat directory.
asset_directory_out: The input, just as an output.
)doc");
REGISTER_OP("GetSession")
.Input("container: string")
.Attr("master_spec: string")
......@@ -42,6 +56,18 @@ This ComputeSession will no longer be available after this op returns.
handle: A handle to a ComputeSession that will be returned to the backing pool.
)doc");
REGISTER_OP("GetSessionCounts")
.Input("container: string")
.Output("stats: int64")
.SetIsStateful()
.Doc(R"doc(
Given a container string, output session counts for that ComputeSessionPool.
container: A unique identifier for the ComputeSessionPool to analyze.
stats: A vector of stats. [0] is the total number of created sessions. [1] is
the number of sessions that are currently not in the pool.
)doc");
REGISTER_OP("InitComponentData")
.Input("handle: string")
.Input("beam_size: int32")
......@@ -123,28 +149,6 @@ component: The name of a Component instance, matching the ComponentSpec.name.
output_handle: A handle to the same ComputeSession after advancement.
)doc");
// Registration for the (placeholder) embedding-initializer op. The only
// change here is fixing the doubled article "an an" in the op documentation.
REGISTER_OP("DragnnEmbeddingInitializer")
    .Output("embeddings: float")
    .Attr("embedding_input: string")
    .Attr("vocab: string")
    .Attr("scaling_coefficient: float = 1.0")
    .Attr("seed: int = 0")
    .Attr("seed2: int = 0")
    .Doc(R"doc(
*** PLACEHOLDER OP - FUNCTIONALITY NOT YET IMPLEMENTED ***
Read embeddings from an input for every key specified in a text vocab file.
embeddings: A tensor containing embeddings from the specified sstable.
embedding_input: Path to location with embedding vectors.
vocab: Path to list of keys corresponding to the input.
scaling_coefficient: A scaling coefficient for the embedding matrix.
seed: If either `seed` or `seed2` are set to be non-zero, the random number
generator is seeded by the given seed. Otherwise, it is seeded by a
random seed.
seed2: A second seed to avoid seed collision.
)doc");
REGISTER_OP("ExtractFixedFeatures")
.Input("handle: string")
.Output("indices: int32")
......
......@@ -13,8 +13,8 @@
// limitations under the License.
// =============================================================================
#ifndef NLP_SAFT_OPENSOURCE_DRAGNN_CORE_RESOURCE_CONTAINER_H_
#define NLP_SAFT_OPENSOURCE_DRAGNN_CORE_RESOURCE_CONTAINER_H_
#ifndef DRAGNN_CORE_RESOURCE_CONTAINER_H_
#define DRAGNN_CORE_RESOURCE_CONTAINER_H_
#include <memory>
......@@ -48,4 +48,4 @@ class ResourceContainer : public tensorflow::ResourceBase {
} // namespace dragnn
} // namespace syntaxnet
#endif // NLP_SAFT_OPENSOURCE_DRAGNN_CORE_RESOURCE_CONTAINER_H_
#endif // DRAGNN_CORE_RESOURCE_CONTAINER_H_
......@@ -26,6 +26,7 @@ cc_library(
deps = [
"//dragnn/components/util:bulk_feature_extractor",
"//dragnn/core:compute_session",
"//dragnn/core:input_batch_cache",
"//dragnn/protos:data_proto",
"//dragnn/protos:spec_proto",
"//syntaxnet:base",
......
......@@ -13,8 +13,8 @@
// limitations under the License.
// =============================================================================
#ifndef NLP_SAFT_OPENSOURCE_DRAGNN_CORE_TEST_GENERIC_H_
#define NLP_SAFT_OPENSOURCE_DRAGNN_CORE_TEST_GENERIC_H_
#ifndef DRAGNN_CORE_TEST_GENERIC_H_
#define DRAGNN_CORE_TEST_GENERIC_H_
#include <utility>
......@@ -31,10 +31,18 @@ MATCHER_P(EqualsProto, a, "Protos are not equivalent:") {
return a.DebugString() == arg.DebugString();
}
// Matches an error status whose message matches |substr|.
MATCHER_P(IsErrorWithSubstr, substr,
string(negation ? "isn't" : "is") +
" an error Status whose message matches the substring '" +
::testing::PrintToString(substr) + "'") {
return !arg.ok() && arg.error_message().find(substr) != string::npos;
}
// Returns the prefix for where the test data is stored.
string GetTestDataPrefix();
} // namespace test
} // namespace syntaxnet
#endif // NLP_SAFT_OPENSOURCE_DRAGNN_CORE_TEST_GENERIC_H_
#endif // DRAGNN_CORE_TEST_GENERIC_H_
......@@ -13,8 +13,8 @@
// limitations under the License.
// =============================================================================
#ifndef NLP_SAFT_OPENSOURCE_DRAGNN_CORE_TEST_MOCK_COMPONENT_H_
#define NLP_SAFT_OPENSOURCE_DRAGNN_CORE_TEST_MOCK_COMPONENT_H_
#ifndef DRAGNN_CORE_TEST_MOCK_COMPONENT_H_
#define DRAGNN_CORE_TEST_MOCK_COMPONENT_H_
#include <gmock/gmock.h>
......@@ -47,8 +47,8 @@ class MockComponent : public Component {
MOCK_CONST_METHOD3(GetBeamIndexAtStep,
int(int step, int current_index, int batch));
MOCK_CONST_METHOD2(GetSourceBeamIndex, int(int current_index, int batch));
MOCK_METHOD2(AdvanceFromPrediction,
void(const float transition_matrix[], int matrix_length));
MOCK_METHOD3(AdvanceFromPrediction, bool(const float *transition_matrix,
int num_items, int num_actions));
MOCK_METHOD0(AdvanceFromOracle, void());
MOCK_CONST_METHOD0(IsTerminal, bool());
MOCK_METHOD0(GetBeam, std::vector<std::vector<const TransitionState *>>());
......@@ -59,6 +59,11 @@ class MockComponent : public Component {
int channel_id));
MOCK_METHOD1(BulkGetFixedFeatures,
int(const BulkFeatureExtractor &extractor));
MOCK_METHOD5(BulkEmbedFixedFeatures,
void(int batch_size_padding, int num_steps_padding,
int output_array_size,
const vector<const float *> &per_channel_embeddings,
float *embedding_output));
MOCK_CONST_METHOD1(GetRawLinkFeatures,
std::vector<LinkFeatures>(int channel_id));
MOCK_CONST_METHOD0(GetOracleLabels, std::vector<std::vector<int>>());
......@@ -75,4 +80,4 @@ class MockComponent : public Component {
} // namespace dragnn
} // namespace syntaxnet
#endif // NLP_SAFT_OPENSOURCE_DRAGNN_CORE_TEST_MOCK_COMPONENT_H_
#endif // DRAGNN_CORE_TEST_MOCK_COMPONENT_H_
......@@ -13,16 +13,18 @@
// limitations under the License.
// =============================================================================
#ifndef NLP_SAFT_OPENSOURCE_DRAGNN_CORE_TEST_MOCK_COMPUTE_SESSION_H_
#define NLP_SAFT_OPENSOURCE_DRAGNN_CORE_TEST_MOCK_COMPUTE_SESSION_H_
#ifndef DRAGNN_CORE_TEST_MOCK_COMPUTE_SESSION_H_
#define DRAGNN_CORE_TEST_MOCK_COMPUTE_SESSION_H_
#include <gmock/gmock.h>
#include <memory>
#include "dragnn/components/util/bulk_feature_extractor.h"
#include "dragnn/core/compute_session.h"
#include "dragnn/core/input_batch_cache.h"
#include "dragnn/protos/data.pb.h"
#include "dragnn/protos/spec.pb.h"
#include "syntaxnet/base.h"
#include <gmock/gmock.h>
#include "tensorflow/core/platform/test.h"
namespace syntaxnet {
......@@ -40,9 +42,9 @@ class MockComputeSession : public ComputeSession {
MOCK_METHOD2(SourceComponentBeamSize,
int(const string &component_name, int channel_id));
MOCK_METHOD1(AdvanceFromOracle, void(const string &component_name));
MOCK_METHOD3(AdvanceFromPrediction,
void(const string &component_name, const float score_matrix[],
int score_matrix_length));
MOCK_METHOD4(AdvanceFromPrediction,
bool(const string &component_name, const float *score_matrix,
int num_items, int num_actions));
MOCK_CONST_METHOD5(GetInputFeatures,
int(const string &component_name,
std::function<int32 *(int)> allocate_indices,
......@@ -52,6 +54,11 @@ class MockComputeSession : public ComputeSession {
MOCK_METHOD2(BulkGetInputFeatures,
int(const string &component_name,
const BulkFeatureExtractor &extractor));
MOCK_METHOD6(BulkEmbedFixedFeatures,
void(const string &component_name, int batch_size_padding,
int num_steps_padding, int output_array_size,
const vector<const float *> &per_channel_embedding,
float *embedding_output));
MOCK_METHOD2(GetTranslatedLinkFeatures,
std::vector<LinkFeatures>(const string &component_name,
int channel_id));
......@@ -68,9 +75,17 @@ class MockComputeSession : public ComputeSession {
MOCK_CONST_METHOD1(GetDescription, string(const string &component_name));
MOCK_CONST_METHOD1(Translators, const std::vector<const IndexTranslator *>(
const string &component_name));
MOCK_CONST_METHOD1(GetReadiedComponent, Component *(const string &name));
// TODO(googleuser): Upgrade gMock to a version that supports mocking methods
// with move-only types, then remove this workaround.
MOCK_METHOD1(DoSetInputBatchCache, void(InputBatchCache *batch));
void SetInputBatchCache(std::unique_ptr<InputBatchCache> batch) override {
DoSetInputBatchCache(batch.get());
}
};
} // namespace dragnn
} // namespace syntaxnet
#endif // NLP_SAFT_OPENSOURCE_DRAGNN_CORE_TEST_MOCK_COMPUTE_SESSION_H_
#endif // DRAGNN_CORE_TEST_MOCK_COMPUTE_SESSION_H_
......@@ -13,8 +13,8 @@
// limitations under the License.
// =============================================================================
#ifndef NLP_SAFT_OPENSOURCE_DRAGNN_CORE_TEST_MOCK_TRANSITION_STATE_H_
#define NLP_SAFT_OPENSOURCE_DRAGNN_CORE_TEST_MOCK_TRANSITION_STATE_H_
#ifndef DRAGNN_CORE_TEST_MOCK_TRANSITION_STATE_H_
#define DRAGNN_CORE_TEST_MOCK_TRANSITION_STATE_H_
#include <memory>
......@@ -31,15 +31,17 @@ class MockTransitionState : public TransitionState {
public:
MOCK_METHOD1(Init, void(const TransitionState &parent));
MOCK_CONST_METHOD0(Clone, std::unique_ptr<TransitionState>());
MOCK_CONST_METHOD0(ParentBeamIndex, const int());
MOCK_METHOD1(SetBeamIndex, void(const int index));
MOCK_CONST_METHOD0(GetBeamIndex, const int());
MOCK_CONST_METHOD0(GetScore, const float());
MOCK_METHOD1(SetScore, void(const float score));
MOCK_CONST_METHOD0(ParentBeamIndex, int());
MOCK_METHOD1(SetBeamIndex, void(int index));
MOCK_CONST_METHOD0(GetBeamIndex, int());
MOCK_CONST_METHOD0(GetScore, float());
MOCK_METHOD1(SetScore, void(float score));
MOCK_CONST_METHOD0(IsGold, bool());
MOCK_METHOD1(SetGold, void(bool is_gold));
MOCK_CONST_METHOD0(HTMLRepresentation, string());
};
} // namespace dragnn
} // namespace syntaxnet
#endif // NLP_SAFT_OPENSOURCE_DRAGNN_CORE_TEST_MOCK_TRANSITION_STATE_H_
#endif // DRAGNN_CORE_TEST_MOCK_TRANSITION_STATE_H_
......@@ -11,10 +11,6 @@ component {
key: "language"
value: "en"
}
parameters {
key: "neurosis_feature_syntax_version"
value: "2"
}
parameters {
key: "parser_skip_deterministic"
value: "false"
......
......@@ -13,8 +13,8 @@
// limitations under the License.
// =============================================================================
#ifndef NLP_SAFT_OPENSOURCE_DRAGNN_IO_SENTENCE_INPUT_BATCH_H_
#define NLP_SAFT_OPENSOURCE_DRAGNN_IO_SENTENCE_INPUT_BATCH_H_
#ifndef DRAGNN_IO_SENTENCE_INPUT_BATCH_H_
#define DRAGNN_IO_SENTENCE_INPUT_BATCH_H_
#include <string>
#include <vector>
......@@ -35,6 +35,9 @@ class SentenceInputBatch : public InputBatch {
void SetData(
const std::vector<string> &stringified_sentence_protos) override;
// Returns the size of the batch.
int GetSize() const override { return data_.size(); }
// Translates to a vector of stringified Sentence protos.
const std::vector<string> GetSerializedData() const override;
......@@ -49,4 +52,4 @@ class SentenceInputBatch : public InputBatch {
} // namespace dragnn
} // namespace syntaxnet
#endif // NLP_SAFT_OPENSOURCE_DRAGNN_IO_SENTENCE_INPUT_BATCH_H_
#endif // DRAGNN_IO_SENTENCE_INPUT_BATCH_H_
......@@ -49,6 +49,9 @@ TEST(SentenceInputBatchTest, ConvertsFromStringifiedProtos) {
EXPECT_NE(converted_data->at(i).workspace(), nullptr);
}
// Check the batch size.
EXPECT_EQ(strings.size(), set.GetSize());
// Get the data back out. The strings should be identical.
auto output = set.GetSerializedData();
EXPECT_EQ(output.size(), strings.size());
......
......@@ -13,8 +13,8 @@
// limitations under the License.
// =============================================================================
#ifndef NLP_SAFT_OPENSOURCE_DRAGNN_IO_SYNTAXNET_SENTENCE_H_
#define NLP_SAFT_OPENSOURCE_DRAGNN_IO_SYNTAXNET_SENTENCE_H_
#ifndef DRAGNN_IO_SYNTAXNET_SENTENCE_H_
#define DRAGNN_IO_SYNTAXNET_SENTENCE_H_
#include "syntaxnet/sentence.pb.h"
#include "syntaxnet/workspace.h"
......@@ -39,4 +39,4 @@ class SyntaxNetSentence {
} // namespace dragnn
} // namespace syntaxnet
#endif // NLP_SAFT_OPENSOURCE_DRAGNN_IO_SYNTAXNET_SENTENCE_H_
#endif // DRAGNN_IO_SYNTAXNET_SENTENCE_H_
......@@ -26,6 +26,12 @@ tf_proto_library(
srcs = ["spec.proto"],
)
tf_proto_library(
name = "runtime_proto",
srcs = ["runtime.proto"],
deps = [":spec_proto"],
)
tf_proto_library_py(
name = "data_py_pb2",
srcs = ["data.proto"],
......
syntax = "proto2";
import "dragnn/protos/spec.proto";
package syntaxnet.dragnn.runtime;
// Performance tuning settings that only affect resource usage, not annotated
// output or correctness. This should be attached to the MasterSpec used to
// initialize a Master.
//
// NEXT ID: 2
message MasterPerformanceSettings {
// Attach these settings to a MasterSpec via this proto2 extension.
extend MasterSpec {
optional MasterPerformanceSettings master_spec_extension = 160848628;
}
// Maximum size of the free list in the SessionStatePool. States beyond this
// limit are presumably deallocated rather than pooled — confirm against the
// SessionStatePool implementation. NB: The default
// value may occasionally change.
optional uint64 session_state_pool_max_free_states = 1 [default = 4];
}
// As above, but for component-specific performance tuning settings.
//
// NEXT ID: 2
message ComponentPerformanceSettings {
// Attach these settings to a ComponentSpec via this proto2 extension.
extend ComponentSpec {
optional ComponentPerformanceSettings component_spec_extension = 160999422;
}
// Number of steps to pre-allocate for the relevant component. NB: The
// default value may occasionally change.
optional uint32 pre_allocate_num_steps = 1 [default = 50];
}
// Specification of an ArrayVariableStore.
//
// NEXT ID: 5
message ArrayVariableStoreSpec {
// Characteristics of the variable data. The binary that loads the variables
// must match these characteristics; mismatches should be rejected at load
// time rather than silently misread.
optional uint32 version = 1; // required version of the byte array format
optional uint32 alignment_bytes = 2; // required alignment of the byte array
optional bool is_little_endian = 3; // required endian-ness of the byte array
// Variable specifications, in order of appearance in the byte array.
repeated VariableSpec variable = 4;
}
// Specification of a single serialized variable.
//
// NEXT ID: 6
message VariableSpec {
// Formats for serialized pre-trained variables. See VariableStore::Lookup()
// for descriptions of the enumerators.
enum Format {
FORMAT_UNKNOWN = 0;
FORMAT_FLAT = 1;
FORMAT_ROW_MAJOR_MATRIX = 2;
FORMAT_COLUMN_BLOCKED_ROW_MAJOR_MATRIX = 3;
}
// Name of the variable.
optional string name = 1;
// Format of the variable.
optional Format format = 2 [default = FORMAT_UNKNOWN];
// Dimensions of the variable. The semantics depend on the format, but are
// always in logical units (number of floats, etc.) rather than bytes:
//
// * FORMAT_FLAT: single value, the length of the vector
// * FORMAT_ROW_MAJOR_MATRIX: two values, [rows, columns]
// * FORMAT_COLUMN_BLOCKED_ROW_MAJOR_MATRIX: three values,
// [rows, columns, block_size] — presumably the column block size, per the
// enumerator name; confirm against VariableStore::Lookup().
repeated uint32 dimension = 5;
// Number of sub-views in the AlignedArea that contained the variable.
optional uint64 num_views = 3;
// Sub-view size in bytes for the AlignedArea that contained the variable.
optional uint64 view_size = 4;
}
......@@ -16,6 +16,7 @@ message MasterSpec {
// Whether to extract debug traces.
optional bool debug_tracing = 4 [default = false];
extensions 1000 to max;
reserved 2, 3, 5;
}
......@@ -28,8 +29,7 @@ message ComponentSpec {
// TransitionSystem to use.
optional RegisteredModuleSpec transition_system = 2;
// Resources that this component depends on. These are copied to TaskInputs
// when calling SAFT code.
// Resources that this component depends on.
repeated Resource resource = 3;
// Feature space configurations.
......@@ -58,6 +58,8 @@ message ComponentSpec {
// Default max number of active states for beam inference.
optional int32 inference_beam_size = 12 [default = 1];
extensions 1000 to max;
}
// Super generic container for any registered sub-piece of DRAGNN.
......@@ -65,14 +67,11 @@ message RegisteredModuleSpec {
// Name of the registered class.
optional string registered_name = 1;
// Parameters to set while initializing this system; these are copied to
// Parameters in a TaskSpec when calling SAFT code, or via kwargs in TF Python
// code.
// Parameters to set while initializing this system.
map<string, string> parameters = 2;
}
// Fixed resources that will be converted into TaskInput's when calling SAFT
// code.
// Fixed resource.
message Resource {
optional string name = 1;
repeated Part part = 2;
......@@ -218,6 +217,9 @@ message GridPoint {
optional double gradient_clip_norm = 11 [default = 0.0];
// A spec for using multiple optimization methods.
//
// This is not guaranteed to work for recursively-defined composite
// optimizers.
message CompositeOptimizerSpec {
// First optimizer.
optional GridPoint method1 = 1;
......@@ -227,6 +229,11 @@ message GridPoint {
// After this number of steps, switch from first to second.
optional int32 switch_after_steps = 3;
// Whether to reset the learning rate (which normally decays) after
// switching optimizers. Limitations: It will only reset to the initial
// learning rate, and won't work for recursively-defined optimizers.
optional bool reset_learning_rate = 4 [default = false];
}
optional CompositeOptimizerSpec composite_optimizer_spec = 12;
......@@ -247,6 +254,7 @@ message GridPoint {
// place. Typically a single component.
optional string self_norm_components_filter = 21;
extensions 1000 to max;
reserved 5, 6;
}
......
......@@ -16,6 +16,11 @@ cc_binary(
],
)
filegroup(
name = "testdata",
data = glob(["testdata/**"]),
)
py_library(
name = "load_dragnn_cc_impl_py",
srcs = ["load_dragnn_cc_impl.py"],
......@@ -64,7 +69,51 @@ py_library(
py_library(
name = "dragnn_ops",
srcs = ["dragnn_ops.py"],
deps = [],
deps = [
":load_dragnn_cc_impl_py",
"//dragnn/core:dragnn_bulk_ops",
"//dragnn/core:dragnn_ops",
"//syntaxnet:load_parser_ops_py",
],
)
py_library(
name = "dragnn_model_saver_lib",
srcs = ["dragnn_model_saver_lib.py"],
deps = [
":dragnn_ops",
":graph_builder",
":load_dragnn_cc_impl_py",
":network_units",
"//dragnn/protos:spec_py_pb2",
"//syntaxnet:load_parser_ops_py",
"//syntaxnet:sentence_py_pb2",
"@org_tensorflow//tensorflow:tensorflow_py",
"@org_tensorflow//tensorflow/core:protos_all_py",
],
)
py_test(
name = "dragnn_model_saver_lib_test",
srcs = ["dragnn_model_saver_lib_test.py"],
data = [":testdata"],
deps = [
":dragnn_model_saver_lib",
"//dragnn/protos:spec_py_pb2",
"@org_tensorflow//tensorflow:tensorflow_py",
],
)
py_binary(
name = "dragnn_model_saver",
srcs = ["dragnn_model_saver.py"],
deps = [
":dragnn_model_saver_lib",
":spec_builder",
"//dragnn/protos:spec_py_pb2",
"@org_tensorflow//tensorflow:tensorflow_py",
"@org_tensorflow//tensorflow/core:protos_all_py",
],
)
py_library(
......@@ -76,6 +125,7 @@ py_library(
":composite_optimizer",
":dragnn_ops",
":network_units",
":transformer_units",
":wrapped_units",
"//dragnn/protos:spec_py_pb2",
"//syntaxnet/util:check",
......@@ -184,10 +234,7 @@ py_test(
":bulk_component",
":components",
":dragnn_ops",
":load_dragnn_cc_impl_py",
":network_units",
"//dragnn/core:dragnn_bulk_ops",
"//dragnn/core:dragnn_ops",
"//dragnn/protos:spec_py_pb2",
"//syntaxnet:load_parser_ops_py",
"//syntaxnet:sentence_py_pb2",
......@@ -201,7 +248,6 @@ py_test(
srcs = ["composite_optimizer_test.py"],
deps = [
":composite_optimizer",
":load_dragnn_cc_impl_py",
"//dragnn/core:dragnn_bulk_ops",
"//dragnn/core:dragnn_ops",
"//syntaxnet:load_parser_ops_py",
......@@ -217,15 +263,13 @@ py_test(
data = [
"//dragnn/core:testdata",
],
shard_count = 5,
tags = [
"notsan",
],
deps = [
":dragnn_ops",
":graph_builder",
":load_dragnn_cc_impl_py",
"//dragnn/core:dragnn_bulk_ops",
"//dragnn/core:dragnn_ops",
"//dragnn/protos:spec_py_pb2",
"//dragnn/protos:trace_py_pb2",
"//syntaxnet:load_parser_ops_py",
......@@ -240,7 +284,6 @@ py_test(
size = "small",
srcs = ["network_units_test.py"],
deps = [
":load_dragnn_cc_impl_py",
":network_units",
"//dragnn/core:dragnn_bulk_ops",
"//dragnn/core:dragnn_ops",
......@@ -256,6 +299,7 @@ py_test(
srcs = ["sentence_io_test.py"],
data = ["//syntaxnet:testdata"],
deps = [
":dragnn_ops",
":sentence_io",
"//syntaxnet:load_parser_ops_py",
"//syntaxnet:parser_ops",
......@@ -373,3 +417,30 @@ py_library(
"@org_tensorflow//tensorflow:tensorflow_py",
],
)
# Transformer network units built on top of the generic network_units library.
py_library(
name = "transformer_units",
srcs = ["transformer_units.py"],
deps = [
":network_units",
"//syntaxnet/util:check",
"@org_tensorflow//tensorflow:tensorflow_py",
],
)
# Unit tests for :transformer_units.
py_test(
name = "transformer_units_test",
size = "small",
srcs = ["transformer_units_test.py"],
deps = [
":network_units",
":transformer_units",
"//dragnn/core:dragnn_bulk_ops",
"//dragnn/core:dragnn_ops",
"//dragnn/protos:spec_py_pb2",
"//syntaxnet:load_parser_ops_py",
"@org_tensorflow//tensorflow:tensorflow_py",
"@org_tensorflow//tensorflow/core:protos_all_py",
],
)
......@@ -95,7 +95,7 @@ class BiaffineDigraphNetwork(network_units.NetworkUnitInterface):
self._regularized_weights.extend(self._weights)
# Negative Layer.dim indicates that the dimension is dynamic.
self._layers.append(network_units.Layer(self, 'adjacency', -1))
self._layers.append(network_units.Layer(component, 'adjacency', -1))
def create(self,
fixed_embeddings,
......@@ -209,7 +209,8 @@ class BiaffineLabelNetwork(network_units.NetworkUnitInterface):
self._params.extend(self._weights + self._biases)
self._regularized_weights.extend(self._weights)
self._layers.append(network_units.Layer(self, 'labels', self._num_labels))
self._layers.append(
network_units.Layer(component, 'labels', self._num_labels))
def create(self,
fixed_embeddings,
......
......@@ -216,9 +216,11 @@ def build_cross_entropy_loss(logits, gold):
logits = tf.gather(logits, valid)
correct = tf.reduce_sum(tf.to_int32(tf.nn.in_top_k(logits, gold, 1)))
total = tf.size(gold)
cost = tf.reduce_sum(
tf.contrib.nn.deprecated_flipped_sparse_softmax_cross_entropy_with_logits(
logits, tf.cast(gold, tf.int64))) / tf.cast(total, tf.float32)
with tf.control_dependencies([tf.assert_positive(total)]):
cost = tf.reduce_sum(
tf.nn.sparse_softmax_cross_entropy_with_logits(
labels=tf.cast(gold, tf.int64), logits=logits)) / tf.cast(
total, tf.float32)
return cost, correct, total
......@@ -267,6 +269,22 @@ class BulkFeatureExtractorComponentBuilder(component.ComponentBuilderBase):
correct, total = tf.constant(0), tf.constant(0)
return state.handle, cost, correct, total
def build_post_restore_hook(self):
  """Builds ops that should be executed once after the restore op.

  This graph is intended to be run a single time, before the inference
  pipeline is run. It delegates to the network's own post-restore hook when
  the network defines one.

  Returns:
    A single-element list holding the network's post-restore op, or an empty
    list when the network has no callable `build_post_restore_hook`.
  """
  logging.info('Building restore hook for component: %s', self.spec.name)
  with tf.variable_scope(self.name):
    network_hook = getattr(self.network, 'build_post_restore_hook', None)
    if callable(network_hook):
      return [network_hook()]
    return []
def build_greedy_inference(self, state, network_states,
during_training=False):
"""Extracts features and advances a batch using the oracle path.
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment