ModelZoo / ResNet50_tensorflow · Commit 4364390a
Authored Nov 13, 2017 by Ivan Bogatyy; committed by calberti, Nov 13, 2017
Release DRAGNN bulk networks (#2785)
* Release DRAGNN bulk networks
parent 638fd759
Showing 20 changed files with 818 additions and 94 deletions (+818 −94)
research/syntaxnet/dragnn/core/ops/dragnn_bulk_ops.cc  +32 −0
research/syntaxnet/dragnn/core/ops/dragnn_op_kernels.cc  +165 −11
research/syntaxnet/dragnn/core/ops/dragnn_op_kernels_test.cc  +320 −4
research/syntaxnet/dragnn/core/ops/dragnn_ops.cc  +26 −22
research/syntaxnet/dragnn/core/resource_container.h  +3 −3
research/syntaxnet/dragnn/core/test/BUILD  +1 −0
research/syntaxnet/dragnn/core/test/generic.h  +11 −3
research/syntaxnet/dragnn/core/test/mock_component.h  +10 −5
research/syntaxnet/dragnn/core/test/mock_compute_session.h  +22 −7
research/syntaxnet/dragnn/core/test/mock_transition_state.h  +10 −8
research/syntaxnet/dragnn/core/testdata/master_spec_link.textproto  +0 −4
research/syntaxnet/dragnn/io/sentence_input_batch.h  +6 −3
research/syntaxnet/dragnn/io/sentence_input_batch_test.cc  +3 −0
research/syntaxnet/dragnn/io/syntaxnet_sentence.h  +3 −3
research/syntaxnet/dragnn/protos/BUILD  +6 −0
research/syntaxnet/dragnn/protos/runtime.proto  +81 −0
research/syntaxnet/dragnn/protos/spec.proto  +15 −7
research/syntaxnet/dragnn/python/BUILD  +80 −9
research/syntaxnet/dragnn/python/biaffine_units.py  +3 −2
research/syntaxnet/dragnn/python/bulk_component.py  +21 −3
research/syntaxnet/dragnn/core/ops/dragnn_bulk_ops.cc
...
...
@@ -80,6 +80,38 @@ pad_to_batch: If set, the op will pad/truncate to this number of elements.
pad_to_steps: If set, the op will pad/truncate to this number of steps.
)doc"
);
REGISTER_OP("BulkEmbedFixedFeatures")
    .Input("handle: string")
    .Input("embedding_matrix: num_channels * float")
    .Output("output_handle: string")
    .Output("embedding_vectors: float")
    .Output("num_steps: int32")
    .Attr("component: string")
    .Attr("num_channels: int")
    .Attr("pad_to_batch: int")
    .Attr("pad_to_steps: int")
    .SetIsStateful()
    .Doc(R"doc(
This op is a more efficient version of BulkFixedFeatures.
It is intended to be run with large batch sizes at inference time. The op takes
a handle to ComputeSession and embedding matrices as tensor inputs, and directly
outputs concatenated embedding vectors. It calls the BulkEmbedFixedFeatures
method on the underlying component directly, so it requires a padding vector
to be passed.
handle: A handle to ComputeSession.
embedding_matrix: Embedding matrices.
output_handle: A handle to the same ComputeSession after advancement.
embedding_vectors: (matrix of float) Concatenated embeddings,
shaped as (batch * beam * token) x sum_channel(embedding_dim[channel]).
num_steps: The batch was unrolled for these many steps.
component: The name of a Component instance, matching the ComponentSpec.name.
num_channels: The number of FixedFeature channels.
pad_to_batch: The op will pad/truncate to this number of elements.
pad_to_steps: The op will pad/truncate to this number of steps.
)doc"
);
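For reference, a minimal sketch of how the interface registered above could be exercised from Python. This is illustrative only and not part of the commit: it assumes the generated snake_case wrapper bulk_embed_fixed_features is exposed through dragnn.python.dragnn_ops and that a ComputeSession handle already exists from a prior GetSession call.

import tensorflow as tf
from dragnn.python import dragnn_ops  # assumed module path

# 'handle' is assumed to be the string handle returned by a GetSession op.
embedding_matrices = [tf.random_normal([1000, 64]),   # channel 0 embeddings
                      tf.random_normal([500, 32])]    # channel 1 embeddings
output_handle, embedding_vectors, num_steps = dragnn_ops.bulk_embed_fixed_features(
    handle,
    embedding_matrices,
    component='tagger',                   # must match a ComponentSpec.name
    num_channels=len(embedding_matrices),
    pad_to_batch=256,
    pad_to_steps=128)
# embedding_vectors is shaped (batch * beam * token) x (64 + 32), i.e. the
# per-channel embedding dimensions concatenated, as described in the doc above.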
REGISTER_OP("BulkAdvanceFromOracle")
    .Input("handle: string")
    .Output("output_handle: string")
...
...
research/syntaxnet/dragnn/core/ops/dragnn_op_kernels.cc
...
...
@@ -30,6 +30,7 @@
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/core/threadpool.h"
#include "tensorflow/core/lib/io/path.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/mutex.h"
...
...
@@ -40,6 +41,8 @@ using tensorflow::DT_INT32;
using tensorflow::DT_INT64;
using tensorflow::DT_STRING;
using tensorflow::DataType;
using tensorflow::io::Dirname;
using tensorflow::io::JoinPath;
using tensorflow::OpKernel;
using tensorflow::OpKernelConstruction;
using tensorflow::OpKernelContext;
...
...
@@ -53,6 +56,59 @@ namespace dragnn {
typedef ResourceContainer<ComputeSession> ComputeSessionResource;
typedef ResourceContainer<ComputeSessionPool> ComputeSessionPoolResource;
typedef ResourceContainer<string> StringResource;

namespace {

const char kGlobalContainer[] = "__reserved_global_container";
const char kBasePathTag[] = "__reserved_asset_base_path";
const char kUnmanagedAssetDirectory[] = "assets.extra";
// When restoring a graph from a SavedModel, this op will rewrite the MasterSpec
// to point the DRAGNN components to the new resource locations. It will then
// add a string resource to the resource manager, which will be used to
// rebuild the masterspec before it is acquired in the GetComputeSession op.
class SetAssetDirectory : public OpKernel {
 public:
  explicit SetAssetDirectory(OpKernelConstruction *context)
      : OpKernel(context) {
    OP_REQUIRES_OK(context, context->MatchSignature({DT_STRING}, {DT_STRING}));
  }

  void Compute(OpKernelContext *context) override {
    ResourceMgr *rmgr = context->resource_manager();
    const string asset_path = context->input(0).scalar<string>()();

    // TODO(googleuser): Get this data in a way that isn't fragile as all hell.
    // "I've done stuff I ain't proud of... and the stuff I am proud of is
    // disgusting." -- Moe
    auto extra_asset_dir =
        JoinPath(Dirname(Dirname(asset_path)), kUnmanagedAssetDirectory);
    LOG(INFO) << "Found extra assets path at:" << extra_asset_dir;

    // Rather than attempt to rewrite the MasterSpec here, we save off a
    // StringResource containing the new asset path. It will be used in
    // the GetSession op, if it exists.
    std::unique_ptr<string> asset_path_ptr(new string(extra_asset_dir));
    OP_REQUIRES_OK(context,
                   rmgr->Create<StringResource>(
                       kGlobalContainer, kBasePathTag,
                       new StringResource(std::move(asset_path_ptr))));

    // This isn't used anywhere - it just allows us to have an output so that
    // it's easier to reason about Tensorflow's graph execution.
    Tensor *output;
    OP_REQUIRES_OK(context,
                   context->allocate_output(0, TensorShape({1}), &output));
    output->vec<string>()(0) = asset_path;
  }

 private:
  TF_DISALLOW_COPY_AND_ASSIGN(SetAssetDirectory);
};

REGISTER_KERNEL_BUILDER(Name("SetAssetDirectory").Device(DEVICE_CPU),
                        SetAssetDirectory);
// Given a MasterSpec proto, outputs a handle to a ComputeSession.
class GetSession : public OpKernel {
...
...
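For reference, the path manipulation in SetAssetDirectory above boils down to stripping the asset file and its immediate parent directory, then appending the SavedModel "assets.extra" directory. A minimal Python sketch (illustrative only, using os.path in place of tensorflow::io) that reproduces the expectation checked later in SetAssetDirectoryTest:

import os

def derive_extra_asset_dir(asset_path):
  # Drop the asset file name and its containing directory, then append the
  # unmanaged-asset directory name ("assets.extra").
  return os.path.join(os.path.dirname(os.path.dirname(asset_path)),
                      'assets.extra')

assert derive_extra_asset_dir(
    'new/directory/path/asset/master_spec') == 'new/directory/path/assets.extra'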
@@ -66,6 +122,7 @@ class GetSession : public OpKernel {
    CHECK(master_spec_.ParseFromString(master_spec_str));
    CHECK(grid_point_.ParseFromString(grid_point_spec_str));
    OP_REQUIRES_OK(context, context->MatchSignature({DT_STRING}, {DT_STRING}));
    has_overwritten_spec_ = false;
  }

  void Compute(OpKernelContext *context) override {
...
...
@@ -74,10 +131,32 @@ class GetSession : public OpKernel {
// Create the pool for this container, or re-use one that was allocated in a
// previous call.
    auto create_pool = [this,
    auto create_pool = [this, &rmgr,
                        &container](ComputeSessionPoolResource **resource) {
      LOG(INFO) << "Creating new ComputeSessionPool in container handle: "
                << container;
      if (has_overwritten_spec_) {
        // TODO(googleuser): Figure out a way to test this.
        // If there's already an overwritten spec, use that.
        LOG(INFO) << "Creating new ComputeSessionPool in container handle: "
                  << container << " with previously overwritten master spec.";
      } else {
        // If not, try to find the resource base.
        StringResource *resource_base;
        auto resource_base_lookup = rmgr->Lookup<StringResource>(
            kGlobalContainer, kBasePathTag, &resource_base);
        if (resource_base_lookup.ok()) {
          // If that exists, the spec must be rewritten.
          string resource_base_path = *resource_base->get();
          LOG(INFO) << "Creating new ComputeSessionPool in container handle: "
                    << container << " using resource directory base "
                    << resource_base_path;
          RewriteMasterSpec(resource_base_path);
          resource_base->Unref();
        } else {
          // If not, just use the spec as is.
          LOG(INFO) << "Creating new ComputeSessionPool in container handle: "
                    << container << " without editing master spec.";
        }
      }
      std::unique_ptr<ComputeSessionPool> pool(
          new ComputeSessionPool(master_spec_, grid_point_));
      *resource = new ComputeSessionPoolResource(std::move(pool));
...
...
@@ -120,6 +199,23 @@ class GetSession : public OpKernel {
}
private:
// Rewrites this op's saved MasterSpec, appending the new base directory.
  void RewriteMasterSpec(const string &new_base) {
    for (auto &component_spec : *master_spec_.mutable_component()) {
      for (auto &resource_def : *component_spec.mutable_resource()) {
        for (auto &part_def : *resource_def.mutable_part()) {
          part_def.set_file_pattern(
              JoinPath(new_base, part_def.file_pattern()));
          VLOG(2) << "New path: " << part_def.file_pattern();
        }
      }
    }
    VLOG(3) << "Rewritten spec: " << master_spec_.DebugString();
    has_overwritten_spec_ = true;
  }

  bool has_overwritten_spec_;
  MasterSpec master_spec_;
  GridPoint grid_point_;
...
...
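For reference, RewriteMasterSpec above is a three-level loop over component/resource/part that prefixes every file_pattern with the new base directory. An illustrative Python equivalent, assuming the generated spec_pb2 module for dragnn/protos/spec.proto (module path is an assumption, not confirmed by this change):

import os
from dragnn.protos import spec_pb2  # assumed generated module

def rewrite_master_spec(master_spec, new_base):
  # Prepend new_base to every resource part path, mirroring the C++ helper.
  for component_spec in master_spec.component:
    for resource in component_spec.resource:
      for part in resource.part:
        part.file_pattern = os.path.join(new_base, part.file_pattern)
  return master_spec

spec = spec_pb2.MasterSpec()
part = spec.component.add().resource.add().part.add()
part.file_pattern = 'path/to/an/asset.txt'
rewrite_master_spec(spec, 'new/base')
assert (spec.component[0].resource[0].part[0].file_pattern ==
        'new/base/path/to/an/asset.txt')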
@@ -141,7 +237,6 @@ REGISTER_KERNEL_BUILDER(Name("GetSession").Device(DEVICE_CPU), GetSession);
class ReleaseSession : public OpKernel {
 public:
  explicit ReleaseSession(OpKernelConstruction *context) : OpKernel(context) {
    string master_spec_str;
    OP_REQUIRES_OK(context, context->MatchSignature({DT_STRING}, {}));
  }
...
...
@@ -188,6 +283,53 @@ class ReleaseSession : public OpKernel {
REGISTER_KERNEL_BUILDER(Name("ReleaseSession").Device(DEVICE_CPU),
                        ReleaseSession);
// Returns statistics about session loads to the graph. This op returns the
// total number of created Session objects and the number of those objects
// that are currently being used in the ComputeSessionPool.
class GetSessionCounts : public OpKernel {
 public:
  explicit GetSessionCounts(OpKernelConstruction *context)
      : OpKernel(context) {
    OP_REQUIRES_OK(context, context->MatchSignature({DT_STRING}, {DT_INT64}));
  }

  void Compute(OpKernelContext *context) override {
    const string container = context->input(0).scalar<string>()();
    VLOG(1) << "Getting stats for container: " << container;
    ResourceMgr *rmgr = context->resource_manager();

    // Allocate the output tensors.
    Tensor *output;
    OP_REQUIRES_OK(context,
                   context->allocate_output(0, TensorShape({2}), &output));

    // Get the pool for this container.
    ComputeSessionPoolResource *pool_resource;
    auto result = rmgr->Lookup<ComputeSessionPoolResource>(container, "pool",
                                                           &pool_resource);
    if (!result.ok()) {
      // If there's no ComputeSessionPoolResource, report 0 sessions created
      // and 0 available.
      output->vec<int64>()(0) = 0;
      output->vec<int64>()(1) = 0;
      return;
    }
    auto *pool = pool_resource->get();
    CHECK(pool != nullptr);
    output->vec<int64>()(0) = pool->num_unique_sessions();
    output->vec<int64>()(1) = pool->num_outstanding_sessions();
    pool_resource->Unref();
  }

 private:
  TF_DISALLOW_COPY_AND_ASSIGN(GetSessionCounts);
};

REGISTER_KERNEL_BUILDER(Name("GetSessionCounts").Device(DEVICE_CPU),
                        GetSessionCounts);
/*******************************************************************************
* ComputeSessionOps below here.
******************************************************************************/
...
...
@@ -233,9 +375,17 @@ class AdvanceFromPrediction : public ComputeSessionOp {
  void ComputeWithState(OpKernelContext *context,
                        ComputeSession *session) override {
    const Tensor &scores = context->input(1);
    session->AdvanceFromPrediction(component_name(),
                                   scores.tensor<float, 2>().data(),
                                   scores.NumElements());
    const int num_items = scores.shape().dim_size(0);
    const int num_actions = scores.shape().dim_size(1);
    bool success = session->AdvanceFromPrediction(
        component_name(), scores.tensor<float, 2>().data(), num_items,
        num_actions);
    if (success) {
      VLOG(2) << "Score: " << scores.tensor<float, 2>();
    }
    OP_REQUIRES(context, success,
                tensorflow::errors::Internal(
                    "Unable to advance from prediction."));
  }
private:
...
...
@@ -247,13 +397,12 @@ REGISTER_KERNEL_BUILDER(Name("AdvanceFromPrediction").Device(DEVICE_CPU),
// Given a handle to a ComputeSession and a channel index, outputs fixed
// features.
// Fixed features are returned as 3 vectors or equal length:
// Fixed features are returned as 3 vectors of equal length:
// - ids: specifies which rows should be looked up in the embedding
//   matrix,
// - weights: specifies a scale for each embedding vector,
// - indices: sorted vector that assigns the same index to embedding
//   vectors
//   that should be summed together.
//   vectors that should be summed together.
//
// For example if we have 3 features, for a given channel, we might have:
// feature a: (5, 1)
...
...
@@ -300,7 +449,10 @@ class ExtractFixedFeatures : public ComputeSessionOp {
    int num_features = session->GetInputFeatures(
        component_name(), indices_allocator, ids_allocator, weights_allocator,
        channel_id_);
    VLOG(2) << "Extracted " << num_features;
    VLOG(2) << "Extracted features (" << num_features << "): "
            << " ids=" << context->mutable_output(1)->vec<int64>()
            << " weights=" << context->mutable_output(2)->vec<float>()
            << " indices=" << context->mutable_output(0)->vec<int32>();
}
private:
...
...
@@ -524,6 +676,7 @@ class AttachDataReader : public ComputeSessionOp {
    auto input_data(context->input(1).vec<string>());
    std::vector<string> data;
    data.reserve(input_data.size());
    for (int i = 0; i < input_data.size(); ++i) {
      data.push_back(input_data(i));
    }
...
...
@@ -642,5 +795,6 @@ class GetComponentTrace : public ComputeSessionOp {
REGISTER_KERNEL_BUILDER(Name("GetComponentTrace").Device(DEVICE_CPU),
                        GetComponentTrace);

}  // namespace
}  // namespace dragnn
}  // namespace syntaxnet
research/syntaxnet/dragnn/core/ops/dragnn_op_kernels_test.cc
...
...
@@ -17,6 +17,7 @@
#include <memory>
#include <vector>
#include "dragnn/core/component_registry.h"
#include "dragnn/core/compute_session.h"
#include "dragnn/core/compute_session_pool.h"
#include "dragnn/core/resource_container.h"
...
...
@@ -66,6 +67,87 @@ using testing::Return;
typedef ResourceContainer<ComputeSession> ComputeSessionResource;
typedef ResourceContainer<ComputeSessionPool> ComputeSessionPoolResource;
typedef ResourceContainer<string> StringResource;

namespace {

const char kGlobalContainer[] = "__reserved_global_container";
const char kBasePathTag[] = "__reserved_asset_base_path";
const char kUnmanagedAssetDirectory[] = "assets.extra";
}  // namespace
// Define a test component to validate registered construction.
class TestComponent : public Component {
 public:
  TestComponent() {}
  void InitializeComponent(const ComponentSpec &spec) override {
    name_ = spec.name();
  }
  void InitializeData(
      const std::vector<std::vector<const TransitionState *>> &states,
      int max_beam_size, InputBatchCache *input_data) override {}
  void InitializeTracing() override {}
  void DisableTracing() override {}
  bool IsReady() const override { return true; }
  string Name() const override { return name_; }
  int BeamSize() const override { return 3; }
  int BatchSize() const override { return 1; }
  int StepsTaken(int batch_index) const override { return 0; }
  int GetBeamIndexAtStep(int step, int current_index,
                         int batch) const override {
    return 0;
  }
  int GetSourceBeamIndex(int current_index, int batch) const override {
    return 0;
  }
  bool AdvanceFromPrediction(const float *score_matrix, int num_items,
                             int num_actions) override {
    return true;
  }
  void AdvanceFromOracle() override {}
  bool IsTerminal() const override { return true; }
  std::function<int(int, int, int)> GetStepLookupFunction(
      const string &method) override {
    return nullptr;
  }
  std::vector<std::vector<const TransitionState *>> GetBeam() override {
    std::vector<std::vector<const TransitionState *>> states;
    return states;
  }
  int GetFixedFeatures(std::function<int32 *(int)> allocate_indices,
                       std::function<int64 *(int)> allocate_ids,
                       std::function<float *(int)> allocate_weights,
                       int channel_id) const override {
    return 0;
  }
  int BulkGetFixedFeatures(const BulkFeatureExtractor &extractor) override {
    return 0;
  }
  void BulkEmbedFixedFeatures(
      int batch_size_padding, int num_steps_padding, int output_array_size,
      const vector<const float *> &per_channel_embeddings,
      float *embedding_matrix) override {}
  std::vector<LinkFeatures> GetRawLinkFeatures(int channel_id) const override {
    std::vector<LinkFeatures> ret;
    return ret;
  }
  std::vector<std::vector<int>> GetOracleLabels() const override {
    std::vector<std::vector<int>> ret;
    return ret;
  }
  void FinalizeData() override {}
  void ResetComponent() override {}
  std::vector<std::vector<ComponentTrace>> GetTraceProtos() const override {
    std::vector<std::vector<ComponentTrace>> ret;
    return ret;
  }
  void AddTranslatedLinkFeaturesToTrace(
      const std::vector<LinkFeatures> &features, int channel_id) override {}

  string name_;
};

REGISTER_DRAGNN_COMPONENT(TestComponent);
class DragnnOpKernelsTest : public tensorflow::OpsTestBase {
 public:
...
...
@@ -106,6 +188,42 @@ LinkFeatures MakeFeatures(int batch_index, int beam_index, int step) {
  return features;
}
// The SetAssetDirectory op should
// 1. When given an asset path (foo/bar/baz/asset/thing), strip the path to
// foo/bar/baz and add 'assets.extra' to it.
// 2. Store that path in the resource manager.
TEST_F(DragnnOpKernelsTest, SetAssetDirectoryTest) {
  // Create a MasterSpec and GridPoint string to pass into the attrs for this
  // op.
  const string new_asset_path = "new/directory/path/asset/master_spec";
  const string expected_asset_path =
      StrCat("new/directory/path/", kUnmanagedAssetDirectory);

  // Create and initialize the kernel under test.
  TF_ASSERT_OK(
      NodeDefBuilder("set_asset_directory", "SetAssetDirectory")
          .Input(FakeInput(DT_STRING))  // The new asset path.
          .Finalize(node_def()));
  TF_ASSERT_OK(InitOp());

  // Set the input data.
  AddInputFromList<string>(TensorShape({1}), {new_asset_path});

  // Reset the test context to ensure it's clean.
  ResetOpKernelContext();

  // Run the kernel.
  TF_EXPECT_OK(RunOpKernelWithContext());

  // Expect that the ResourceMgr contains the correct string.
  StringResource *resource;
  TF_EXPECT_OK(resource_mgr()->Lookup<StringResource>(kGlobalContainer,
                                                      kBasePathTag, &resource));
  EXPECT_EQ(*resource->get(), expected_asset_path);
  resource->Unref();
}
// The GetSessionOp should
// 1. create a ComputeSessionPool resource and store it in the ResourceMgr,
// 2. create a ComputeSession resource and store it in the ResourceMgr,
...
...
@@ -164,6 +282,103 @@ TEST_F(DragnnOpKernelsTest, GetSessionOpTest) {
  pool_resource->Unref();
}
// If an asset_base_path resource exists, the GetSession op should prepend
// that path to all paths in the MasterSpec before creating a session.
TEST_F(DragnnOpKernelsTest, GetSessionWithAssetBasePathTest) {
  // Create a MasterSpec and GridPoint string to pass into the attrs for this
  // op.
  const string new_asset_path = "new/base";
  MasterSpec spec;

  // The first component in the MasterSpec has one resource with one part.
  auto component_one = spec.add_component();
  auto backend_one = component_one->mutable_backend();
  backend_one->set_registered_name("TestComponent");
  component_one->add_resource()->add_part()->set_file_pattern(
      "path/to/an/asset.txt");
  const string expected_component_one_asset = "new/base/path/to/an/asset.txt";

  auto component_two = spec.add_component();
  auto backend_two = component_two->mutable_backend();
  backend_two->set_registered_name("TestComponent");

  // The second component's first resource has no assets.
  component_two->add_resource();

  // The second component's second resource has one part.
  vector<string> expected_component_two_assets;
  component_two->add_resource()->add_part()->set_file_pattern(
      "another/dir/with/an/asset.txt");
  expected_component_two_assets.push_back(
      "new/base/another/dir/with/an/asset.txt");

  // The second component's third resource has two parts.
  auto third_resource = component_two->add_resource();
  third_resource->add_part()->set_file_pattern(
      "another/dir/with/an/asset3.jif");
  expected_component_two_assets.push_back(
      "new/base/another/dir/with/an/asset3.jif");
  third_resource->add_part()->set_file_pattern(
      "another/dir/with/an/asset4.jif");
  expected_component_two_assets.push_back(
      "new/base/another/dir/with/an/asset4.jif");

  LOG(INFO) << spec.DebugString();

  string master_spec_str;
  spec.SerializeToString(&master_spec_str);

  GridPoint hyperparams;
  string hyperparams_str;
  hyperparams.SerializeToString(&hyperparams_str);

  // Create and initialize the kernel under test.
  TF_ASSERT_OK(
      NodeDefBuilder("get_session", "GetSession")
          .Attr("master_spec", master_spec_str)
          .Attr("grid_point", hyperparams_str)
          .Input(FakeInput(DT_STRING))  // The handle for the ComputeSession.
          .Finalize(node_def()));
  TF_ASSERT_OK(InitOp());

  // Set the input data.
  const string container_string = "container_str";
  AddInputFromList<string>(TensorShape({1}), {container_string});

  // Reset the test context to ensure it's clean.
  ResetOpKernelContext();

  // Create the string in the resource manager.
  std::unique_ptr<string> asset_path_ptr(new string(new_asset_path));
  TF_EXPECT_OK(resource_mgr()->Create<StringResource>(
      kGlobalContainer, kBasePathTag,
      new StringResource(std::move(asset_path_ptr))));

  // Run the kernel.
  TF_EXPECT_OK(RunOpKernelWithContext());

  // Expect that the ResourceMgr contains a ComputeSessionPoolResource.
  const string pool_id_str = "pool";
  ComputeSessionPoolResource *pool_resource;
  TF_EXPECT_OK(resource_mgr()->Lookup<ComputeSessionPoolResource>(
      container_string, pool_id_str, &pool_resource));

  // Validate that the master spec held by the pool has the new directory names.
  auto rewritten_spec = pool_resource->get()->GetSpec();
  EXPECT_EQ(rewritten_spec.component(0).resource(0).part(0).file_pattern(),
            expected_component_one_asset);
  EXPECT_EQ(rewritten_spec.component(1).resource(1).part(0).file_pattern(),
            expected_component_two_assets.at(0));
  EXPECT_EQ(rewritten_spec.component(1).resource(2).part(0).file_pattern(),
            expected_component_two_assets.at(1));
  EXPECT_EQ(rewritten_spec.component(1).resource(2).part(1).file_pattern(),
            expected_component_two_assets.at(2));

  // Unref the managed resources so they get destroyed properly.
  pool_resource->Unref();
}
// The GetSessionOp should take a session stored in the resource manager
// and return it to the ComputeSessionPool.
TEST_F(DragnnOpKernelsTest, ReleaseSessionOpTest) {
...
...
@@ -217,6 +432,56 @@ TEST_F(DragnnOpKernelsTest, ReleaseSessionOpTest) {
  EXPECT_EQ(null_resource, nullptr);
}
// The GetSessionCounts op should report the number of sessions created and
// free.
TEST_F(DragnnOpKernelsTest, GetSessionCountsOpTest) {
  // Create and initialize the kernel under test.
  TF_ASSERT_OK(
      NodeDefBuilder("get_session_counts", "GetSessionCounts")
          .Input(FakeInput(DT_STRING))  // The handle for the ComputeSession.
          .Finalize(node_def()));
  TF_ASSERT_OK(InitOp());

  // Set the input data.
  const string container_string = "container_str";
  AddInputFromList<string>(TensorShape({1}), {container_string});

  // Reset the test context to ensure it's clean.
  ResetOpKernelContext();

  // Create a ComputeSessionPool.
  MasterSpec spec;
  GridPoint hyperparams;
  std::unique_ptr<ComputeSessionPool> pool(
      new ComputeSessionPool(spec, hyperparams));

  // Get an unowned pointer to the ComputeSessionPool before moving
  // the pool to the resource manager.
  ComputeSessionPool *pool_ptr = pool.get();
  TF_ASSERT_OK(resource_mgr()->Create<ComputeSessionPoolResource>(
      container_string, "pool",
      new ComputeSessionPoolResource(std::move(pool))));

  // Create two ComputeSessions.
  auto session_one = pool_ptr->GetSession();
  auto session_two = pool_ptr->GetSession();

  // Return one of them.
  pool_ptr->ReturnSession(std::move(session_two));

  // At this point, the pool should report that it has one outstanding session
  // and two sessions total.
  EXPECT_EQ(1, pool_ptr->num_outstanding_sessions());
  EXPECT_EQ(2, pool_ptr->num_unique_sessions());

  // Run the kernel.
  TF_EXPECT_OK(RunOpKernelWithContext());

  EXPECT_EQ(pool_ptr->num_unique_sessions(), GetOutput(0)->vec<int64>()(0));
  EXPECT_EQ(pool_ptr->num_outstanding_sessions(),
            GetOutput(0)->vec<int64>()(1));
}
// The AdvanceFromOracle op should call AdvanceFromOracle on the specified
// component name.
TEST_F
(
DragnnOpKernelsTest
,
AdvanceFromOracleOpTest
)
{
...
...
@@ -287,14 +552,65 @@ TEST_F(DragnnOpKernelsTest, AdvanceFromPredictionOpTest) {
  // Set expectations on the mock session.
  auto validator_function = [weights](const string &component_name,
                                      const float score_matrix[],
                                      int score_matrix_length) {
    EXPECT_EQ(weights.size(), score_matrix_length);
                                      const float *score_matrix, int num_items,
                                      int num_actions) {
    EXPECT_EQ(weights.size(), num_items * num_actions);
    for (int i = 0; i < weights.size(); ++i) {
      EXPECT_EQ(weights[i], score_matrix[i]);
    }
    return true;
  };
  EXPECT_CALL(*mock_session_ptr,
              AdvanceFromPrediction(component_name, _, _, _))
      .WillOnce(Invoke(validator_function));

  // Run the kernel.
  TF_EXPECT_OK(RunOpKernelWithContext());
}
// The AdvanceFromPrediction op should call AdvanceFromPrediction on the
// specified component with the passed scores. If it returns false, the op
// should not return OK.
TEST_F(DragnnOpKernelsTest, AdvanceFromPredictionFailureTest) {
  // Create and initialize the kernel under test.
  const string component_name = "TESTING_COMPONENT_NAME";
  TF_ASSERT_OK(
      NodeDefBuilder("advance_from_prediction", "AdvanceFromPrediction")
          .Attr("component", component_name)
          .Input(FakeInput(DT_STRING))  // The handle for the ComputeSession.
          .Input(FakeInput(DT_FLOAT))   // The prediction tensor.
          .Finalize(node_def()));
  TF_ASSERT_OK(InitOp());

  // Set the input data.
  const string container_string = "container_str";
  const string id_string = "id_str";
  AddInputFromList<string>(TensorShape({2}), {container_string, id_string});
  const std::vector<float> weights = {1.1, 2.2, 3.3, 4.4};
  AddInputFromArray<float>(TensorShape({2, 2}), weights);

  // Reset the test context to ensure it's clean.
  ResetOpKernelContext();

  // Create a MockComputeSession and set expectations.
  std::unique_ptr<MockComputeSession> mock_session(new MockComputeSession());
  MockComputeSession *mock_session_ptr = mock_session.get();

  // Wrap the ComputeSessionResource and put it into the resource manager.
  TF_ASSERT_OK(resource_mgr()->Create<ComputeSessionResource>(
      container_string, id_string,
      new ComputeSessionResource(std::move(mock_session))));

  // Set expectations on the mock session.
  auto validator_function = [weights](const string &component_name,
                                      const float *score_matrix, int num_items,
                                      int num_actions) {
    EXPECT_EQ(weights.size(), num_items * num_actions);
    for (int i = 0; i < weights.size(); ++i) {
      EXPECT_EQ(weights[i], score_matrix[i]);
    }
    return true;
  };
  EXPECT_CALL(*mock_session_ptr,
              AdvanceFromPrediction(component_name, _, _))
  EXPECT_CALL(*mock_session_ptr,
              AdvanceFromPrediction(component_name, _, _, _))
      .WillOnce(Invoke(validator_function));

  // Run the kernel.
...
...
research/syntaxnet/dragnn/core/ops/dragnn_ops.cc
...
...
@@ -18,6 +18,20 @@
namespace syntaxnet {
namespace dragnn {

REGISTER_OP("SetAssetDirectory")
    .Input("asset_directory: string")
    .Output("asset_directory_out: string")
    .SetIsStateful()
    .Doc(R"doc(
Override the paths to assets specified in the MasterSpec with the given
asset_directory. This op must be called before any calls to GetSession, as it
will create a new session pool with the overridden master spec.
asset_directory: The directory containing all the assets. Note that all assets
must be in a single flat directory.
asset_directory_out: The input, just as an output.
)doc"
);
REGISTER_OP("GetSession")
    .Input("container: string")
    .Attr("master_spec: string")
...
...
@@ -42,6 +56,18 @@ This ComputeSession will no longer be available after this op returns.
handle: A handle to a ComputeSession that will be returned to the backing pool.
)doc"
);
REGISTER_OP("GetSessionCounts")
    .Input("container: string")
    .Output("stats: int64")
    .SetIsStateful()
    .Doc(R"doc(
Given a container string, output session counts for that ComputeSessionPool.
container: A unique identifier for the ComputeSessionPool to analyze.
stats: A vector of stats. [0] is the total number of created sessions. [1] is
the number of sessions that are currently not in the pool.
)doc"
);
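For reference, an illustrative way to read these counts from Python — assuming the generated snake_case wrapper get_session_counts is exposed through dragnn_ops (the wrapper name and module path are assumptions, not confirmed by this change):

import tensorflow as tf
from dragnn.python import dragnn_ops  # assumed module path

stats = dragnn_ops.get_session_counts('container_str')  # assumed wrapper name
with tf.Session() as sess:
  total_created, outstanding = sess.run(stats)
  # total_created: total number of ComputeSessions the pool has ever created.
  # outstanding: sessions currently checked out of the pool (not yet released).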
REGISTER_OP("InitComponentData")
    .Input("handle: string")
    .Input("beam_size: int32")
...
...
@@ -123,28 +149,6 @@ component: The name of a Component instance, matching the ComponentSpec.name.
output_handle: A handle to the same ComputeSession after advancement.
)doc"
);
REGISTER_OP("DragnnEmbeddingInitializer")
    .Output("embeddings: float")
    .Attr("embedding_input: string")
    .Attr("vocab: string")
    .Attr("scaling_coefficient: float = 1.0")
    .Attr("seed: int = 0")
    .Attr("seed2: int = 0")
    .Doc(R"doc(
*** PLACEHOLDER OP - FUNCTIONALITY NOT YET IMPLEMENTED ***
Read embeddings from an input for every key specified in a text vocab file.
embeddings: A tensor containing embeddings from the specified sstable.
embedding_input: Path to location with embedding vectors.
vocab: Path to list of keys corresponding to the input.
scaling_coefficient: A scaling coefficient for the embedding matrix.
seed: If either `seed` or `seed2` are set to be non-zero, the random number
generator is seeded by the given seed. Otherwise, it is seeded by a
random seed.
seed2: A second seed to avoid seed collision.
)doc"
);
REGISTER_OP("ExtractFixedFeatures")
    .Input("handle: string")
    .Output("indices: int32")
...
...
research/syntaxnet/dragnn/core/resource_container.h
...
...
@@ -13,8 +13,8 @@
// limitations under the License.
// =============================================================================
#ifndef NLP_SAFT_OPENSOURCE_DRAGNN_CORE_RESOURCE_CONTAINER_H_
#define NLP_SAFT_OPENSOURCE_DRAGNN_CORE_RESOURCE_CONTAINER_H_
#ifndef DRAGNN_CORE_RESOURCE_CONTAINER_H_
#define DRAGNN_CORE_RESOURCE_CONTAINER_H_
#include <memory>
...
...
@@ -48,4 +48,4 @@ class ResourceContainer : public tensorflow::ResourceBase {
}  // namespace dragnn
}  // namespace syntaxnet
#endif  // NLP_SAFT_OPENSOURCE_DRAGNN_CORE_RESOURCE_CONTAINER_H_
#endif  // DRAGNN_CORE_RESOURCE_CONTAINER_H_
research/syntaxnet/dragnn/core/test/BUILD
...
...
@@ -26,6 +26,7 @@ cc_library(
    deps = [
        "//dragnn/components/util:bulk_feature_extractor",
        "//dragnn/core:compute_session",
        "//dragnn/core:input_batch_cache",
        "//dragnn/protos:data_proto",
        "//dragnn/protos:spec_proto",
        "//syntaxnet:base",
...
...
research/syntaxnet/dragnn/core/test/generic.h
...
...
@@ -13,8 +13,8 @@
// limitations under the License.
// =============================================================================
#ifndef NLP_SAFT_OPENSOURCE_DRAGNN_CORE_TEST_GENERIC_H_
#define NLP_SAFT_OPENSOURCE_DRAGNN_CORE_TEST_GENERIC_H_
#ifndef DRAGNN_CORE_TEST_GENERIC_H_
#define DRAGNN_CORE_TEST_GENERIC_H_
#include <utility>
...
...
@@ -31,10 +31,18 @@ MATCHER_P(EqualsProto, a, "Protos are not equivalent:") {
  return a.DebugString() == arg.DebugString();
}

// Matches an error status whose message matches |substr|.
MATCHER_P(IsErrorWithSubstr, substr,
          string(negation ? "isn't" : "is") +
              " an error Status whose message matches the substring '" +
              ::testing::PrintToString(substr) + "'") {
  return !arg.ok() && arg.error_message().find(substr) != string::npos;
}
// Returns the prefix for where the test data is stored.
string GetTestDataPrefix();
}  // namespace test
}  // namespace syntaxnet
#endif  // NLP_SAFT_OPENSOURCE_DRAGNN_CORE_TEST_GENERIC_H_
#endif  // DRAGNN_CORE_TEST_GENERIC_H_
research/syntaxnet/dragnn/core/test/mock_component.h
...
...
@@ -13,8 +13,8 @@
// limitations under the License.
// =============================================================================
#ifndef NLP_SAFT_OPENSOURCE_DRAGNN_CORE_TEST_MOCK_COMPONENT_H_
#define NLP_SAFT_OPENSOURCE_DRAGNN_CORE_TEST_MOCK_COMPONENT_H_
#ifndef DRAGNN_CORE_TEST_MOCK_COMPONENT_H_
#define DRAGNN_CORE_TEST_MOCK_COMPONENT_H_
#include <gmock/gmock.h>
...
...
@@ -47,8 +47,8 @@ class MockComponent : public Component {
  MOCK_CONST_METHOD3(GetBeamIndexAtStep,
                     int(int step, int current_index, int batch));
  MOCK_CONST_METHOD2(GetSourceBeamIndex, int(int current_index, int batch));
  MOCK_METHOD2(AdvanceFromPrediction,
               void(const float transition_matrix[], int matrix_length));
  MOCK_METHOD3(AdvanceFromPrediction,
               bool(const float *transition_matrix, int num_items,
                    int num_actions));
  MOCK_METHOD0(AdvanceFromOracle, void());
  MOCK_CONST_METHOD0(IsTerminal, bool());
  MOCK_METHOD0(GetBeam, std::vector<std::vector<const TransitionState *>>());
...
...
@@ -59,6 +59,11 @@ class MockComponent : public Component {
                             int channel_id));
  MOCK_METHOD1(BulkGetFixedFeatures,
               int(const BulkFeatureExtractor &extractor));
  MOCK_METHOD5(BulkEmbedFixedFeatures,
               void(int batch_size_padding, int num_steps_padding,
                    int output_array_size,
                    const vector<const float *> &per_channel_embeddings,
                    float *embedding_output));
  MOCK_CONST_METHOD1(GetRawLinkFeatures,
                     std::vector<LinkFeatures>(int channel_id));
  MOCK_CONST_METHOD0(GetOracleLabels, std::vector<std::vector<int>>());
...
...
@@ -75,4 +80,4 @@ class MockComponent : public Component {
}  // namespace dragnn
}  // namespace syntaxnet
#endif  // NLP_SAFT_OPENSOURCE_DRAGNN_CORE_TEST_MOCK_COMPONENT_H_
#endif  // DRAGNN_CORE_TEST_MOCK_COMPONENT_H_
research/syntaxnet/dragnn/core/test/mock_compute_session.h
...
...
@@ -13,16 +13,18 @@
// limitations under the License.
// =============================================================================
#ifndef NLP_SAFT_OPENSOURCE_DRAGNN_CORE_TEST_MOCK_COMPUTE_SESSION_H_
#define NLP_SAFT_OPENSOURCE_DRAGNN_CORE_TEST_MOCK_COMPUTE_SESSION_H_
#ifndef DRAGNN_CORE_TEST_MOCK_COMPUTE_SESSION_H_
#define DRAGNN_CORE_TEST_MOCK_COMPUTE_SESSION_H_

#include <gmock/gmock.h>
#include <memory>
#include "dragnn/components/util/bulk_feature_extractor.h"
#include "dragnn/core/compute_session.h"
#include "dragnn/core/input_batch_cache.h"
#include "dragnn/protos/data.pb.h"
#include "dragnn/protos/spec.pb.h"
#include "syntaxnet/base.h"
#include <gmock/gmock.h>
#include "tensorflow/core/platform/test.h"
namespace syntaxnet {
...
...
@@ -40,9 +42,9 @@ class MockComputeSession : public ComputeSession {
  MOCK_METHOD2(SourceComponentBeamSize,
               int(const string &component_name, int channel_id));
  MOCK_METHOD1(AdvanceFromOracle, void(const string &component_name));
  MOCK_METHOD3(AdvanceFromPrediction,
               void(const string &component_name,
                    const float score_matrix[], int score_matrix_length));
  MOCK_METHOD4(AdvanceFromPrediction,
               bool(const string &component_name, const float *score_matrix,
                    int num_items, int num_actions));
  MOCK_CONST_METHOD5(GetInputFeatures,
                     int(const string &component_name,
                         std::function<int32 *(int)> allocate_indices,
...
...
@@ -52,6 +54,11 @@ class MockComputeSession : public ComputeSession {
  MOCK_METHOD2(BulkGetInputFeatures,
               int(const string &component_name,
                   const BulkFeatureExtractor &extractor));
  MOCK_METHOD6(BulkEmbedFixedFeatures,
               void(const string &component_name, int batch_size_padding,
                    int num_steps_padding, int output_array_size,
                    const vector<const float *> &per_channel_embedding,
                    float *embedding_output));
  MOCK_METHOD2(GetTranslatedLinkFeatures,
               std::vector<LinkFeatures>(const string &component_name,
                                         int channel_id));
...
...
@@ -68,9 +75,17 @@ class MockComputeSession : public ComputeSession {
  MOCK_CONST_METHOD1(GetDescription, string(const string &component_name));
  MOCK_CONST_METHOD1(Translators,
                     const std::vector<const IndexTranslator *>(
                         const string &component_name));
  MOCK_CONST_METHOD1(GetReadiedComponent, Component *(const string &name));

  // TODO(googleuser): Upgrade gMock to a version that supports mocking methods
  // with move-only types, then remove this workaround.
  MOCK_METHOD1(DoSetInputBatchCache, void(InputBatchCache *batch));
  void SetInputBatchCache(std::unique_ptr<InputBatchCache> batch) override {
    DoSetInputBatchCache(batch.get());
  }
};
}  // namespace dragnn
}  // namespace syntaxnet
#endif  // NLP_SAFT_OPENSOURCE_DRAGNN_CORE_TEST_MOCK_COMPUTE_SESSION_H_
#endif  // DRAGNN_CORE_TEST_MOCK_COMPUTE_SESSION_H_
research/syntaxnet/dragnn/core/test/mock_transition_state.h
...
...
@@ -13,8 +13,8 @@
// limitations under the License.
// =============================================================================
#ifndef NLP_SAFT_OPENSOURCE_DRAGNN_CORE_TEST_MOCK_TRANSITION_STATE_H_
#define NLP_SAFT_OPENSOURCE_DRAGNN_CORE_TEST_MOCK_TRANSITION_STATE_H_
#ifndef DRAGNN_CORE_TEST_MOCK_TRANSITION_STATE_H_
#define DRAGNN_CORE_TEST_MOCK_TRANSITION_STATE_H_
#include <memory>
...
...
@@ -31,15 +31,17 @@ class MockTransitionState : public TransitionState {
 public:
  MOCK_METHOD1(Init, void(const TransitionState &parent));
  MOCK_CONST_METHOD0(Clone, std::unique_ptr<TransitionState>());
  MOCK_CONST_METHOD0(ParentBeamIndex, const int());
  MOCK_METHOD1(SetBeamIndex, void(const int index));
  MOCK_CONST_METHOD0(GetBeamIndex, const int());
  MOCK_CONST_METHOD0(GetScore, const float());
  MOCK_METHOD1(SetScore, void(const float score));
  MOCK_CONST_METHOD0(ParentBeamIndex, int());
  MOCK_METHOD1(SetBeamIndex, void(int index));
  MOCK_CONST_METHOD0(GetBeamIndex, int());
  MOCK_CONST_METHOD0(GetScore, float());
  MOCK_METHOD1(SetScore, void(float score));
  MOCK_CONST_METHOD0(IsGold, bool());
  MOCK_METHOD1(SetGold, void(bool is_gold));
  MOCK_CONST_METHOD0(HTMLRepresentation, string());
};
}  // namespace dragnn
}  // namespace syntaxnet
#endif  // NLP_SAFT_OPENSOURCE_DRAGNN_CORE_TEST_MOCK_TRANSITION_STATE_H_
#endif  // DRAGNN_CORE_TEST_MOCK_TRANSITION_STATE_H_
research/syntaxnet/dragnn/core/testdata/master_spec_link.textproto
...
...
@@ -11,10 +11,6 @@ component {
key: "language"
value: "en"
}
parameters {
key: "neurosis_feature_syntax_version"
value: "2"
}
parameters {
key: "parser_skip_deterministic"
value: "false"
...
...
research/syntaxnet/dragnn/io/sentence_input_batch.h
...
...
@@ -13,8 +13,8 @@
// limitations under the License.
// =============================================================================
#ifndef NLP_SAFT_OPENSOURCE_DRAGNN_IO_SENTENCE_INPUT_BATCH_H_
#define NLP_SAFT_OPENSOURCE_DRAGNN_IO_SENTENCE_INPUT_BATCH_H_
#ifndef DRAGNN_IO_SENTENCE_INPUT_BATCH_H_
#define DRAGNN_IO_SENTENCE_INPUT_BATCH_H_
#include <string>
#include <vector>
...
...
@@ -35,6 +35,9 @@ class SentenceInputBatch : public InputBatch {
  void SetData(
      const std::vector<string> &stringified_sentence_protos) override;

  // Returns the size of the batch.
  int GetSize() const override { return data_.size(); }

  // Translates to a vector of stringified Sentence protos.
  const std::vector<string> GetSerializedData() const override;
...
...
@@ -49,4 +52,4 @@ class SentenceInputBatch : public InputBatch {
}  // namespace dragnn
}  // namespace syntaxnet
#endif  // NLP_SAFT_OPENSOURCE_DRAGNN_IO_SENTENCE_INPUT_BATCH_H_
#endif  // DRAGNN_IO_SENTENCE_INPUT_BATCH_H_
research/syntaxnet/dragnn/io/sentence_input_batch_test.cc
...
...
@@ -49,6 +49,9 @@ TEST(SentenceInputBatchTest, ConvertsFromStringifiedProtos) {
    EXPECT_NE(converted_data->at(i).workspace(), nullptr);
  }

  // Check the batch size.
  EXPECT_EQ(strings.size(), set.GetSize());

  // Get the data back out. The strings should be identical.
  auto output = set.GetSerializedData();
  EXPECT_EQ(output.size(), strings.size());
...
...
research/syntaxnet/dragnn/io/syntaxnet_sentence.h
...
...
@@ -13,8 +13,8 @@
// limitations under the License.
// =============================================================================
#ifndef NLP_SAFT_OPENSOURCE_DRAGNN_IO_SYNTAXNET_SENTENCE_H_
#define NLP_SAFT_OPENSOURCE_DRAGNN_IO_SYNTAXNET_SENTENCE_H_
#ifndef DRAGNN_IO_SYNTAXNET_SENTENCE_H_
#define DRAGNN_IO_SYNTAXNET_SENTENCE_H_
#include "syntaxnet/sentence.pb.h"
#include "syntaxnet/workspace.h"
...
...
@@ -39,4 +39,4 @@ class SyntaxNetSentence {
}  // namespace dragnn
}  // namespace syntaxnet
#endif  // NLP_SAFT_OPENSOURCE_DRAGNN_IO_SYNTAXNET_SENTENCE_H_
#endif  // DRAGNN_IO_SYNTAXNET_SENTENCE_H_
research/syntaxnet/dragnn/protos/BUILD
...
...
@@ -26,6 +26,12 @@ tf_proto_library(
    srcs = ["spec.proto"],
)

tf_proto_library(
    name = "runtime_proto",
    srcs = ["runtime.proto"],
    deps = [":spec_proto"],
)

tf_proto_library_py(
    name = "data_py_pb2",
    srcs = ["data.proto"],
...
...
research/syntaxnet/dragnn/protos/runtime.proto
0 → 100644
syntax = "proto2";

import "dragnn/protos/spec.proto";

package syntaxnet.dragnn.runtime;

// Performance tuning settings that only affect resource usage, not annotated
// output or correctness. This should be attached to the MasterSpec used to
// initialize a Master.
//
// NEXT ID: 2
message MasterPerformanceSettings {
  extend MasterSpec {
    optional MasterPerformanceSettings master_spec_extension = 160848628;
  }

  // Maximum size of the free list in the SessionStatePool. NB: The default
  // value may occasionally change.
  optional uint64 session_state_pool_max_free_states = 1 [default = 4];
}

// As above, but for component-specific performance tuning settings.
//
// NEXT ID: 2
message ComponentPerformanceSettings {
  extend ComponentSpec {
    optional ComponentPerformanceSettings component_spec_extension = 160999422;
  }

  // Number of steps to pre-allocate for the relevant component. NB: The
  // default value may occasionally change.
  optional uint32 pre_allocate_num_steps = 1 [default = 50];
}

// Specification of an ArrayVariableStore.
//
// NEXT ID: 5
message ArrayVariableStoreSpec {
  // Characteristics of the variable data. The binary that loads the variables
  // must match these characteristics.
  optional uint32 version = 1;          // required version of the byte array format
  optional uint32 alignment_bytes = 2;  // required alignment of the byte array
  optional bool is_little_endian = 3;   // required endian-ness of the byte array

  // Variable specifications, in order of appearance in the byte array.
  repeated VariableSpec variable = 4;
}

// Specification of a single serialized variable.
//
// NEXT ID: 6
message VariableSpec {
  // Formats for serialized pre-trained variables. See VariableStore::Lookup()
  // for descriptions of the enumerators.
  enum Format {
    FORMAT_UNKNOWN = 0;
    FORMAT_FLAT = 1;
    FORMAT_ROW_MAJOR_MATRIX = 2;
    FORMAT_COLUMN_BLOCKED_ROW_MAJOR_MATRIX = 3;
  }

  // Name of the variable.
  optional string name = 1;

  // Format of the variable.
  optional Format format = 2 [default = FORMAT_UNKNOWN];

  // Dimensions of variables. The semantics depends on the format, but is
  // always in logical units (number of floats, etc.) rather than bytes,
  //
  //   * flat: single value with the length of the vector
  //   * row-major and column-major: two values, [rows, columns]
  //   * row-blocked column-major: three values, [rows, columns, row_block_size]
  repeated uint32 dimension = 5;

  // Number of sub-views in the AlignedArea that contained the variable.
  optional uint64 num_views = 3;

  // Sub-view size in bytes for the AlignedArea that contained the variable.
  optional uint64 view_size = 4;
}
research/syntaxnet/dragnn/protos/spec.proto
...
...
@@ -16,6 +16,7 @@ message MasterSpec {
  // Whether to extract debug traces.
  optional bool debug_tracing = 4 [default = false];

  extensions 1000 to max;

  reserved 2, 3, 5;
}
...
...
@@ -28,8 +29,7 @@ message ComponentSpec {
  // TransitionSystem to use.
  optional RegisteredModuleSpec transition_system = 2;

  // Resources that this component depends on. These are copied to TaskInputs
  // when calling SAFT code.
  // Resources that this component depends on.
  repeated Resource resource = 3;
// Feature space configurations.
...
...
@@ -58,6 +58,8 @@ message ComponentSpec {
  // Default max number of active states for beam inference.
  optional int32 inference_beam_size = 12 [default = 1];

  extensions 1000 to max;
}
// Super generic container for any registered sub-piece of DRAGNN.
...
...
@@ -65,14 +67,11 @@ message RegisteredModuleSpec {
  // Name of the registered class.
  optional string registered_name = 1;

  // Parameters to set while initializing this system; these are copied to
  // Parameters in a TaskSpec when calling SAFT code, or via kwargs in TF Python
  // code.
  // Parameters to set while initializing this system.
  map<string, string> parameters = 2;
}
// Fixed resources that will be converted into TaskInput's when calling SAFT
// code.
// Fixed resource.
message Resource {
  optional string name = 1;
  repeated Part part = 2;
...
...
@@ -218,6 +217,9 @@ message GridPoint {
  optional double gradient_clip_norm = 11 [default = 0.0];

  // A spec for using multiple optimization methods.
  //
  // This is not guaranteed to work for recursively-defined composite
  // optimizers.
  message CompositeOptimizerSpec {
    // First optimizer.
    optional GridPoint method1 = 1;
...
...
@@ -227,6 +229,11 @@ message GridPoint {
    // After this number of steps, switch from first to second.
    optional int32 switch_after_steps = 3;

    // Whether to reset the learning rate (which normally decays) after
    // switching optimizers. Limitations: It will only reset to the initial
    // learning rate, and won't work for recursively-defined optimizers.
    optional bool reset_learning_rate = 4 [default = false];
  }
  optional CompositeOptimizerSpec composite_optimizer_spec = 12;
...
...
@@ -247,6 +254,7 @@ message GridPoint {
  // place. Typically a single component.
  optional string self_norm_components_filter = 21;

  extensions 1000 to max;

  reserved 5, 6;
}
...
...
research/syntaxnet/dragnn/python/BUILD
...
...
@@ -16,6 +16,11 @@ cc_binary(
    ],
)

filegroup(
    name = "testdata",
    data = glob(["testdata/**"]),
)

py_library(
    name = "load_dragnn_cc_impl_py",
    srcs = ["load_dragnn_cc_impl.py"],
...
...
@@ -64,7 +69,51 @@ py_library(
py_library(
    name = "dragnn_ops",
    srcs = ["dragnn_ops.py"],
    deps = [],
    deps = [
        ":load_dragnn_cc_impl_py",
        "//dragnn/core:dragnn_bulk_ops",
        "//dragnn/core:dragnn_ops",
        "//syntaxnet:load_parser_ops_py",
    ],
)

py_library(
    name = "dragnn_model_saver_lib",
    srcs = ["dragnn_model_saver_lib.py"],
    deps = [
        ":dragnn_ops",
        ":graph_builder",
        ":load_dragnn_cc_impl_py",
        ":network_units",
        "//dragnn/protos:spec_py_pb2",
        "//syntaxnet:load_parser_ops_py",
        "//syntaxnet:sentence_py_pb2",
        "@org_tensorflow//tensorflow:tensorflow_py",
        "@org_tensorflow//tensorflow/core:protos_all_py",
    ],
)

py_test(
    name = "dragnn_model_saver_lib_test",
    srcs = ["dragnn_model_saver_lib_test.py"],
    data = [":testdata"],
    deps = [
        ":dragnn_model_saver_lib",
        "//dragnn/protos:spec_py_pb2",
        "@org_tensorflow//tensorflow:tensorflow_py",
    ],
)

py_binary(
    name = "dragnn_model_saver",
    srcs = ["dragnn_model_saver.py"],
    deps = [
        ":dragnn_model_saver_lib",
        ":spec_builder",
        "//dragnn/protos:spec_py_pb2",
        "@org_tensorflow//tensorflow:tensorflow_py",
        "@org_tensorflow//tensorflow/core:protos_all_py",
    ],
)

py_library(
...
...
@@ -76,6 +125,7 @@ py_library(
":composite_optimizer"
,
":dragnn_ops"
,
":network_units"
,
":transformer_units"
,
":wrapped_units"
,
"//dragnn/protos:spec_py_pb2"
,
"//syntaxnet/util:check"
,
...
...
@@ -184,10 +234,7 @@ py_test(
":bulk_component"
,
":components"
,
":dragnn_ops"
,
":load_dragnn_cc_impl_py"
,
":network_units"
,
"//dragnn/core:dragnn_bulk_ops"
,
"//dragnn/core:dragnn_ops"
,
"//dragnn/protos:spec_py_pb2"
,
"//syntaxnet:load_parser_ops_py"
,
"//syntaxnet:sentence_py_pb2"
,
...
...
@@ -201,7 +248,6 @@ py_test(
    srcs = ["composite_optimizer_test.py"],
    deps = [
        ":composite_optimizer",
        ":load_dragnn_cc_impl_py",
        "//dragnn/core:dragnn_bulk_ops",
        "//dragnn/core:dragnn_ops",
        "//syntaxnet:load_parser_ops_py",
...
...
@@ -217,15 +263,13 @@ py_test(
    data = [
        "//dragnn/core:testdata",
    ],
    shard_count = 5,
    tags = [
        "notsan",
    ],
    deps = [
        ":dragnn_ops",
        ":graph_builder",
        ":load_dragnn_cc_impl_py",
        "//dragnn/core:dragnn_bulk_ops",
        "//dragnn/core:dragnn_ops",
        "//dragnn/protos:spec_py_pb2",
        "//dragnn/protos:trace_py_pb2",
        "//syntaxnet:load_parser_ops_py",
...
...
@@ -240,7 +284,6 @@ py_test(
    size = "small",
    srcs = ["network_units_test.py"],
    deps = [
        ":load_dragnn_cc_impl_py",
        ":network_units",
        "//dragnn/core:dragnn_bulk_ops",
        "//dragnn/core:dragnn_ops",
...
...
@@ -256,6 +299,7 @@ py_test(
    srcs = ["sentence_io_test.py"],
    data = ["//syntaxnet:testdata"],
    deps = [
        ":dragnn_ops",
        ":sentence_io",
        "//syntaxnet:load_parser_ops_py",
        "//syntaxnet:parser_ops",
...
...
@@ -373,3 +417,30 @@ py_library(
"@org_tensorflow//tensorflow:tensorflow_py"
,
],
)
py_library
(
name
=
"transformer_units"
,
srcs
=
[
"transformer_units.py"
],
deps
=
[
":network_units"
,
"//syntaxnet/util:check"
,
"@org_tensorflow//tensorflow:tensorflow_py"
,
],
)
py_test
(
name
=
"transformer_units_test"
,
size
=
"small"
,
srcs
=
[
"transformer_units_test.py"
],
deps
=
[
":network_units"
,
":transformer_units"
,
"//dragnn/core:dragnn_bulk_ops"
,
"//dragnn/core:dragnn_ops"
,
"//dragnn/protos:spec_py_pb2"
,
"//syntaxnet:load_parser_ops_py"
,
"@org_tensorflow//tensorflow:tensorflow_py"
,
"@org_tensorflow//tensorflow/core:protos_all_py"
,
],
)
research/syntaxnet/dragnn/python/biaffine_units.py
...
...
@@ -95,7 +95,7 @@ class BiaffineDigraphNetwork(network_units.NetworkUnitInterface):
    self._regularized_weights.extend(self._weights)

    # Negative Layer.dim indicates that the dimension is dynamic.
    self._layers.append(network_units.Layer(self, 'adjacency', -1))
    self._layers.append(network_units.Layer(component, 'adjacency', -1))

  def create(self,
             fixed_embeddings,
...
...
@@ -209,7 +209,8 @@ class BiaffineLabelNetwork(network_units.NetworkUnitInterface):
    self._params.extend(self._weights + self._biases)
    self._regularized_weights.extend(self._weights)

    self._layers.append(
        network_units.Layer(self, 'labels', self._num_labels))
    self._layers.append(
        network_units.Layer(component, 'labels', self._num_labels))

  def create(self,
             fixed_embeddings,
...
...
research/syntaxnet/dragnn/python/bulk_component.py
...
...
@@ -216,9 +216,11 @@ def build_cross_entropy_loss(logits, gold):
  logits = tf.gather(logits, valid)
  correct = tf.reduce_sum(tf.to_int32(tf.nn.in_top_k(logits, gold, 1)))
  total = tf.size(gold)
  cost = tf.reduce_sum(
      tf.contrib.nn.deprecated_flipped_sparse_softmax_cross_entropy_with_logits(
          logits, tf.cast(gold, tf.int64))) / tf.cast(total, tf.float32)
  with tf.control_dependencies([tf.assert_positive(total)]):
    cost = tf.reduce_sum(
        tf.nn.sparse_softmax_cross_entropy_with_logits(
            labels=tf.cast(gold, tf.int64), logits=logits)) / tf.cast(
                total, tf.float32)
  return cost, correct, total
...
...
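For reference, the hunk above swaps the deprecated flipped-argument contrib op for the keyword-argument tf.nn version, with a control dependency asserting the batch is non-empty. A standalone TF 1.x sketch with toy values (illustrative only, not part of the commit):

import tensorflow as tf

logits = tf.constant([[2.0, 0.5, 0.1],
                      [0.2, 1.5, 0.3]])  # [batch, num_actions]
gold = tf.constant([0, 1])               # gold action per batch element
total = tf.size(gold)

with tf.control_dependencies([tf.assert_positive(total)]):
  cost = tf.reduce_sum(
      tf.nn.sparse_softmax_cross_entropy_with_logits(
          labels=tf.cast(gold, tf.int64), logits=logits)) / tf.cast(
              total, tf.float32)

with tf.Session() as sess:
  print(sess.run(cost))  # mean cross-entropy over the two elements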
@@ -267,6 +269,22 @@ class BulkFeatureExtractorComponentBuilder(component.ComponentBuilderBase):
      correct, total = tf.constant(0), tf.constant(0)
    return state.handle, cost, correct, total
  def build_post_restore_hook(self):
    """Builds a graph that should be executed after the restore op.

    This graph is intended to be run once, before the inference pipeline is
    run.

    Returns:
      setup_op - An op that, when run, guarantees all setup ops will run.
    """
    logging.info('Building restore hook for component: %s', self.spec.name)
    with tf.variable_scope(self.name):
      if callable(getattr(self.network, 'build_post_restore_hook', None)):
        return [self.network.build_post_restore_hook()]
      else:
        return []

  def build_greedy_inference(self, state, network_states,
                             during_training=False):
"""Extracts features and advances a batch using the oracle path.
...
...