ModelZoo / ResNet50_tensorflow · Commits

Commit 80178fc6 (unverified)
Authored May 11, 2018 by Mark Omernick; committed via GitHub on May 11, 2018.

    Merge pull request #4153 from terryykoo/master

    Export @195097388.

Parents: a84e1ef9, edea2b67
Changes: 145. Showing 20 changed files with 1187 additions and 280 deletions (+1187, -280).
Files shown on this page:

  research/syntaxnet/dragnn/core/compute_session_impl.h (+15, -5)
  research/syntaxnet/dragnn/core/compute_session_impl_test.cc (+36, -200)
  research/syntaxnet/dragnn/core/compute_session_pool.cc (+22, -14)
  research/syntaxnet/dragnn/core/compute_session_pool.h (+10, -5)
  research/syntaxnet/dragnn/core/index_translator.cc (+1, -1)
  research/syntaxnet/dragnn/core/interfaces/BUILD (+3, -2)
  research/syntaxnet/dragnn/core/interfaces/component.h (+15, -1)
  research/syntaxnet/dragnn/core/interfaces/transition_state.h (+2, -2)
  research/syntaxnet/dragnn/core/ops/compute_session_op.cc (+4, -0)
  research/syntaxnet/dragnn/core/ops/dragnn_bulk_op_kernels.cc (+89, -6)
  research/syntaxnet/dragnn/core/ops/dragnn_bulk_op_kernels_test.cc (+9, -5)
  research/syntaxnet/dragnn/core/ops/dragnn_bulk_ops.cc (+77, -8)
  research/syntaxnet/dragnn/core/ops/dragnn_op_kernels.cc (+274, -8)
  research/syntaxnet/dragnn/core/ops/dragnn_op_kernels_test.cc (+289, -16)
  research/syntaxnet/dragnn/core/ops/dragnn_ops.cc (+148, -1)
  research/syntaxnet/dragnn/core/ops/shape_helpers.h (+55, -0)
  research/syntaxnet/dragnn/core/test/BUILD (+12, -4)
  research/syntaxnet/dragnn/core/test/fake_component_base.h (+106, -0)
  research/syntaxnet/dragnn/core/test/generic.h (+12, -1)
  research/syntaxnet/dragnn/core/test/mock_component.h (+8, -1)
research/syntaxnet/dragnn/core/compute_session_impl.h

@@ -16,20 +16,23 @@
 #ifndef DRAGNN_CORE_COMPUTE_SESSION_IMPL_H_
 #define DRAGNN_CORE_COMPUTE_SESSION_IMPL_H_

 #include <map>
 #include <memory>

 #include "dragnn/components/util/bulk_feature_extractor.h"
 #include "dragnn/core/compute_session.h"
 #include "dragnn/core/index_translator.h"
 #include "dragnn/core/input_batch_cache.h"
+#include "dragnn/core/util/label.h"
 #include "dragnn/protos/data.pb.h"
 #include "dragnn/protos/spec.pb.h"
 #include "dragnn/protos/trace.pb.h"

 namespace syntaxnet {
 namespace dragnn {

-class ComputeSessionImpl : public ComputeSession {
+class ComputeSessionImpl final : public ComputeSession {
  public:
   // Creates a ComputeSessionImpl with the provided component builder function.
   ComputeSessionImpl(

@@ -77,7 +80,7 @@ class ComputeSessionImpl : public ComputeSession {
   std::vector<LinkFeatures> GetTranslatedLinkFeatures(
       const string &component_name, int channel_id) override;

-  std::vector<std::vector<int>> EmitOracleLabels(
+  std::vector<std::vector<std::vector<Label>>> EmitOracleLabels(
       const string &component_name) override;

   bool IsTerminal(const string &component_name) override;

@@ -92,6 +95,8 @@ class ComputeSessionImpl : public ComputeSession {
   void SetInputBatchCache(std::unique_ptr<InputBatchCache> batch) override;

+  InputBatchCache *GetInputBatchCache() override;
+
   void ResetSession() override;

   void SetTracing(bool tracing_on) override;

@@ -108,6 +113,11 @@ class ComputeSessionImpl : public ComputeSession {
   Component *GetReadiedComponent(const string &component_name) const override;

  private:
+  // Mapping from Keys to Values.
+  template <class Key, class Value>
+  using Mapping = std::map<Key, Value>;
+
   // Get a given component. Fails if the component is not found.
   Component *GetComponent(const string &component_name) const;

@@ -124,11 +134,11 @@ class ComputeSessionImpl : public ComputeSession {
   // Holds all of the components owned by this ComputeSession, associated with
   // their names in the MasterSpec.
-  std::map<string, std::unique_ptr<Component>> components_;
+  Mapping<string, std::unique_ptr<Component>> components_;

   // Holds a vector of translators for each component, indexed by the name
   // of the component they belong to.
-  std::map<string, std::vector<IndexTranslator *>> translators_;
+  Mapping<string, std::vector<IndexTranslator *>> translators_;

   // Holds ownership of all the IndexTranslators for this compute session.
   std::vector<std::unique_ptr<IndexTranslator>> owned_translators_;

@@ -136,7 +146,7 @@ class ComputeSessionImpl : public ComputeSession {
   // The predecessor component for every component.
   // If a component is not in this map, it has no predecessor component and
   // will have its beam initialized without any data from other components.
-  std::map<Component *, Component *> predecessors_;
+  Mapping<Component *, Component *> predecessors_;

   // Holds the current input data for this ComputeSession.
   std::unique_ptr<InputBatchCache> input_data_;
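Editor's note: the diff above changes EmitOracleLabels() from returning plain int labels to returning a nested structure indexed as [batch][beam][labels]. The following is an illustrative sketch, not part of the commit; the Label struct is assumed to carry an id and a probability, matching how label fields are used elsewhere in this diff.

  #include <vector>

  // Assumed shape of the label record used in this commit.
  struct Label {
    int id;
    float probability;
  };

  // Flattens the new EmitOracleLabels() result to one id per (batch, beam)
  // slot, taking the first gold label when several are valid (mirroring the
  // "pick the first one" convention used by the kernels in this diff).
  std::vector<int> FirstGoldIds(
      const std::vector<std::vector<std::vector<Label>>> &batched_labels) {
    std::vector<int> ids;
    for (const auto &beam : batched_labels) {    // one entry per batch element
      for (const auto &instance : beam) {        // one entry per beam slot
        ids.push_back(instance.empty() ? -1 : instance.front().id);
      }
    }
    return ids;
  }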
research/syntaxnet/dragnn/core/compute_session_impl_test.cc

@@ -25,240 +25,49 @@
 #include "dragnn/core/input_batch_cache.h"
 #include "dragnn/core/interfaces/component.h"
 #include "dragnn/core/interfaces/input_batch.h"
+#include "dragnn/core/test/fake_component_base.h"
 #include "dragnn/core/test/generic.h"
 #include "dragnn/core/test/mock_component.h"
 #include "dragnn/core/test/mock_transition_state.h"
+#include "dragnn/core/util/label.h"
 #include "tensorflow/core/platform/test.h"

 namespace syntaxnet {
 namespace dragnn {

 using syntaxnet::test::EqualsProto;
-using testing::_;
 using testing::ElementsAre;
+using testing::NotNull;
 using testing::Return;
+using testing::_;

 // *****************************************************************************
 // Test-internal class definitions.
 // *****************************************************************************

 // Define a test component to validate registered construction.
-class TestComponentType1 : public Component {
+class TestComponentType1 : public FakeComponentBase {
  public:
  TestComponentType1() {}
  void InitializeComponent(const ComponentSpec &spec) override {
    name_ = spec.name();
  }
  void InitializeData(
      const std::vector<std::vector<const TransitionState *>> &states,
      int max_beam_size, InputBatchCache *input_data) override {}
  void InitializeTracing() override {}
  void DisableTracing() override {}
  bool IsReady() const override { return true; }
  string Name() const override { return name_; }
  int BeamSize() const override { return 3; }
  int BatchSize() const override { return 1; }
  int StepsTaken(int batch_index) const override { return 0; }
  int GetBeamIndexAtStep(int step, int current_index,
                         int batch) const override {
    return 0;
  }
  int GetSourceBeamIndex(int current_index, int batch) const override {
    return 0;
  }
  bool AdvanceFromPrediction(const float *score_matrix, int num_items,
                             int num_actions) override {
    return true;
  }
  void AdvanceFromOracle() override {}
  bool IsTerminal() const override { return true; }
  std::function<int(int, int, int)> GetStepLookupFunction(
      const string &method) override {
    return nullptr;
  }
  std::vector<std::vector<const TransitionState *>> GetBeam() override {
    std::vector<std::vector<const TransitionState *>> states;
    return states;
  }
  int GetFixedFeatures(std::function<int32 *(int)> allocate_indices,
                       std::function<int64 *(int)> allocate_ids,
                       std::function<float *(int)> allocate_weights,
                       int channel_id) const override {
    return 0;
  }
  void BulkEmbedFixedFeatures(
      int batch_size_padding, int num_steps_padding, int embedding_size,
      const vector<const float *> &per_channel_embeddings,
      float *embedding_output) override {}
  int BulkGetFixedFeatures(const BulkFeatureExtractor &extractor) override {
    return 0;
  }
  std::vector<LinkFeatures> GetRawLinkFeatures(int channel_id) const override {
    std::vector<LinkFeatures> ret;
    return ret;
  }
  std::vector<std::vector<int>> GetOracleLabels() const override {
    std::vector<std::vector<int>> ret;
    return ret;
  }
  void FinalizeData() override {}
  void ResetComponent() override {}
  std::vector<std::vector<ComponentTrace>> GetTraceProtos() const override {
    std::vector<std::vector<ComponentTrace>> ret;
    return ret;
  }
  void AddTranslatedLinkFeaturesToTrace(
      const std::vector<LinkFeatures> &features, int channel_id) override {}

  string name_;
 };
 REGISTER_DRAGNN_COMPONENT(TestComponentType1);

 // Define a second test component to validate registered construction.
-class TestComponentType2 : public Component {
+class TestComponentType2 : public FakeComponentBase {
  public:
  TestComponentType2() {}
  void InitializeComponent(const ComponentSpec &spec) override {
    name_ = spec.name();
  }
  void InitializeData(
      const std::vector<std::vector<const TransitionState *>> &states,
      int max_beam_size, InputBatchCache *input_data) override {}
  void InitializeTracing() override {}
  void DisableTracing() override {}
  bool IsReady() const override { return true; }
  string Name() const override { return name_; }
  int BeamSize() const override { return 4; }
  int BatchSize() const override { return 2; }
  int StepsTaken(int batch_index) const override { return 0; }
  int GetBeamIndexAtStep(int step, int current_index,
                         int batch) const override {
    return 0;
  }
  int GetSourceBeamIndex(int current_index, int batch) const override {
    return 0;
  }
  bool AdvanceFromPrediction(const float *score_matrix, int num_items,
                             int num_actions) override {
    return true;
  }
  void AdvanceFromOracle() override {}
  bool IsTerminal() const override { return true; }
  std::function<int(int, int, int)> GetStepLookupFunction(
      const string &method) override {
    return nullptr;
  }
  std::vector<std::vector<const TransitionState *>> GetBeam() override {
    std::vector<std::vector<const TransitionState *>> states;
    return states;
  }
  int GetFixedFeatures(std::function<int32 *(int)> allocate_indices,
                       std::function<int64 *(int)> allocate_ids,
                       std::function<float *(int)> allocate_weights,
                       int channel_id) const override {
    return 0;
  }
  void BulkEmbedFixedFeatures(
      int batch_size_padding, int num_steps_padding, int embedding_size,
      const vector<const float *> &per_channel_embeddings,
      float *embedding_output) override {}
  int BulkGetFixedFeatures(const BulkFeatureExtractor &extractor) override {
    return 0;
  }
  std::vector<LinkFeatures> GetRawLinkFeatures(int channel_id) const override {
    std::vector<LinkFeatures> ret;
    return ret;
  }
  std::vector<std::vector<int>> GetOracleLabels() const override {
    std::vector<std::vector<int>> ret;
    return ret;
  }
  void FinalizeData() override {}
  void ResetComponent() override {}
  std::vector<std::vector<ComponentTrace>> GetTraceProtos() const override {
    std::vector<std::vector<ComponentTrace>> ret;
    return ret;
  }
  void AddTranslatedLinkFeaturesToTrace(
      const std::vector<LinkFeatures> &features, int channel_id) override {}

  string name_;
 };
 REGISTER_DRAGNN_COMPONENT(TestComponentType2);

 // Define a component that returns false for IsReady and IsTerminal.
-class UnreadyComponent : public Component {
+class UnreadyComponent : public FakeComponentBase {
  public:
  UnreadyComponent() {}
  void InitializeComponent(const ComponentSpec &spec) override {
    name_ = spec.name();
  }
  void InitializeData(
      const std::vector<std::vector<const TransitionState *>> &states,
      int max_beam_size, InputBatchCache *input_data) override {}
  void InitializeTracing() override {}
  void DisableTracing() override {}
  bool IsReady() const override { return false; }
  string Name() const override { return name_; }
  int BeamSize() const override { return 1; }
  int BatchSize() const override { return 2; }
  int StepsTaken(int batch_index) const override { return 0; }
  int GetBeamIndexAtStep(int step, int current_index,
                         int batch) const override {
    return 0;
  }
  int GetSourceBeamIndex(int current_index, int batch) const override {
    return 0;
  }
  bool AdvanceFromPrediction(const float *score_matrix, int num_items,
                             int num_actions) override {
    return true;
  }
  void BulkEmbedFixedFeatures(
      int batch_size_padding, int num_steps_padding, int embedding_size,
      const vector<const float *> &per_channel_embeddings,
      float *embedding_output) override {}
  void AdvanceFromOracle() override {}
  bool IsTerminal() const override { return false; }
  std::function<int(int, int, int)> GetStepLookupFunction(
      const string &method) override {
    return nullptr;
  }
  std::vector<std::vector<const TransitionState *>> GetBeam() override {
    std::vector<std::vector<const TransitionState *>> states;
    return states;
  }
  int GetFixedFeatures(std::function<int32 *(int)> allocate_indices,
                       std::function<int64 *(int)> allocate_ids,
                       std::function<float *(int)> allocate_weights,
                       int channel_id) const override {
    return 0;
  }
  int BulkGetFixedFeatures(const BulkFeatureExtractor &extractor) override {
    return 0;
  }
  std::vector<LinkFeatures> GetRawLinkFeatures(int channel_id) const override {
    std::vector<LinkFeatures> ret;
    return ret;
  }
  std::vector<std::vector<int>> GetOracleLabels() const override {
    std::vector<std::vector<int>> ret;
    return ret;
  }
  void FinalizeData() override {}
  void ResetComponent() override {}
  std::vector<std::vector<ComponentTrace>> GetTraceProtos() const override {
    std::vector<std::vector<ComponentTrace>> ret;
    return ret;
  }
  void AddTranslatedLinkFeaturesToTrace(
      const std::vector<LinkFeatures> &features, int channel_id) override {}

  string name_;
 };
 REGISTER_DRAGNN_COMPONENT(UnreadyComponent);

@@ -850,7 +659,7 @@ TEST(ComputeSessionImplTest,
  // The death expectation is interacting strangely with this test, so I need
  // to wrap the function in a lambda.
-  EXPECT_DEATH(function_that_will_die(), "Source is not terminal");
+  EXPECT_DEATH(function_that_will_die(), "is not terminal");
 }

 TEST(ComputeSessionImplTest, ResetSessionResetsAllComponents) {

@@ -1147,7 +956,10 @@ TEST(ComputeSessionImplTest, InterfacePassesThrough) {
  session->BulkEmbedFixedFeatures("component_one", 1, 2, 3, {nullptr}, nullptr);

  // EmitOracleLabels()
-  std::vector<std::vector<int>> oracle_labels = {{0, 1}, {2, 3}};
+  // The size of oracle_labels is batch_size * beam_size * num_labels.
+  const std::vector<std::vector<std::vector<Label>>> oracle_labels{
+      {{{0, 1.f}}, {{1, 1.f}}}, {{{2, 1.f}}, {{3, 1.f}}}};
  EXPECT_CALL(*mock_components["component_one"], GetOracleLabels())
      .WillOnce(Return(oracle_labels));
  EXPECT_EQ(oracle_labels, session->EmitOracleLabels("component_one"));

@@ -1227,5 +1039,29 @@ TEST(ComputeSessionImplTest, SetInputBatchCache) {
  EXPECT_EQ(session->GetSerializedPredictions(), data);
 }

+TEST(ComputeSessionImplTest, GetInputBatchCache) {
+  // Use empty protos since we won't interact with components.
+  MasterSpec spec;
+  GridPoint hyperparams;
+  ComputeSessionPool pool(spec, hyperparams);
+  auto session = pool.GetSession();
+
+  // No input data yet.
+  EXPECT_EQ(session->GetInputBatchCache(), nullptr);
+
+  // Set some data, expect some batch to be returned.
+  session->SetInputData({"arbitrary_data"});
+  EXPECT_NE(session->GetInputBatchCache(), nullptr);
+
+  // Create a dummy batch.
+  const std::vector<string> data = {"foo", "bar", "baz"};
+  std::unique_ptr<InputBatchCache> input_batch_cache(new InputBatchCache(data));
+  InputBatchCache *input_batch_cache_ptr = input_batch_cache.get();
+
+  // Inject a batch, expect that batch to be returned.
+  session->SetInputBatchCache(std::move(input_batch_cache));
+  EXPECT_EQ(session->GetInputBatchCache(), input_batch_cache_ptr);
+}
+
 }  // namespace dragnn
 }  // namespace syntaxnet
research/syntaxnet/dragnn/core/compute_session_pool.cc

@@ -33,9 +33,9 @@ ComputeSessionPool::ComputeSessionPool(const MasterSpec &master_spec,
      num_unique_sessions_(0) {
  // Create a default component builder function. This function looks up
  // components in the component registry and returns them.
  component_builder_ = [](const string &component_name,
                          const string &backend_type)
      -> std::unique_ptr<Component> {
    VLOG(2) << "Creating component " << component_name << " with backend "
            << backend_type;
    std::unique_ptr<Component> component(Component::Create(backend_type));

@@ -45,7 +45,7 @@ ComputeSessionPool::ComputeSessionPool(const MasterSpec &master_spec,
  // Create a default session builder function. This function returns a
  // ComputeSessionImpl that uses the currently set component_builder_
  // function to create its components.
-  session_builder_ = [this]() {
+  session_builder_ = [this]() EXCLUSIVE_LOCKS_REQUIRED(lock_) {
    return std::unique_ptr<ComputeSession>(
        new ComputeSessionImpl(num_unique_sessions_, this->component_builder_));
  };

@@ -75,20 +75,28 @@ void ComputeSessionPool::SetComponentBuilder(
 }

 std::unique_ptr<ComputeSession> ComputeSessionPool::GetSession() {
-  mutex_lock lock(lock_);
  std::unique_ptr<ComputeSession> session_ptr;
-  if (sessions_.empty()) {
-    // There are no available sessions, so create and initialize one.
+  bool is_new = false;
+  {
+    // This mutex effectively single-threads the application at this point,
+    // since all ComputeSessions must call here; to minimize impact, we
+    // subscope it.
+    mutex_lock lock(lock_);
+    if (!sessions_.empty()) {
+      VLOG(2) << "Reusing session from pool of size " << sessions_.size();
+      session_ptr = std::move(sessions_.back());
+      sessions_.pop_back();
+    } else {
+      session_ptr = session_builder_();
+      is_new = true;
+      num_unique_sessions_++;
+    }
+  }
+  if (is_new) {
    VLOG(2) << "Creating new session.";
-    session_ptr = session_builder_();
-    num_unique_sessions_++;
    session_ptr->Init(master_spec_, hyperparams_);
  } else {
-    // Get the last free session, and remove it from the free sessions vector.
-    VLOG(2) << "Reusing session from pool of size " << sessions_.size();
-    session_ptr = std::move(sessions_.back());
-    sessions_.pop_back();
    session_ptr->ResetSession();
  }
  return session_ptr;
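Editor's note: the GetSession() change above follows a common pattern: hold the pool lock only long enough to pop or create the session object, then run the expensive Init()/ResetSession() work outside the lock. A minimal generic sketch of that pattern, using std::mutex rather than TensorFlow's mutex and hypothetical type names:

  #include <memory>
  #include <mutex>
  #include <vector>

  struct Session {
    void Init() {}          // expensive; done outside the critical section
    void ResetSession() {}  // done outside the critical section as well
  };

  class Pool {
   public:
    std::unique_ptr<Session> Get() {
      std::unique_ptr<Session> session;
      bool is_new = false;
      {
        std::lock_guard<std::mutex> lock(mu_);  // subscoped critical section
        if (!free_.empty()) {
          session = std::move(free_.back());
          free_.pop_back();
        } else {
          session.reset(new Session());
          is_new = true;
        }
      }
      if (is_new) {
        session->Init();
      } else {
        session->ResetSession();
      }
      return session;
    }

   private:
    std::mutex mu_;
    std::vector<std::unique_ptr<Session>> free_;
  };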
research/syntaxnet/dragnn/core/compute_session_pool.h

@@ -21,6 +21,7 @@
 #include "dragnn/core/compute_session.h"
 #include "dragnn/protos/spec.pb.h"
 #include "tensorflow/core/platform/mutex.h"
+#include "tensorflow/core/platform/thread_annotations.h"

 namespace syntaxnet {
 namespace dragnn {

@@ -50,7 +51,10 @@ class ComputeSessionPool {
  }

  // Returns the number of unique sessions that have been created.
-  int num_unique_sessions() { return num_unique_sessions_; }
+  int num_unique_sessions() {
+    tensorflow::mutex_lock lock(lock_);
+    return num_unique_sessions_;
+  }

  // Returns a reference to the underlying spec for this pool.
  const MasterSpec &GetSpec() const { return master_spec_; }

@@ -82,21 +86,22 @@ class ComputeSessionPool {
  const GridPoint hyperparams_;

  // The function that is used to create ComputeSessions.
-  std::function<std::unique_ptr<ComputeSession>()> session_builder_;
+  std::function<std::unique_ptr<ComputeSession>()> session_builder_
+      GUARDED_BY(lock_);

  // The function passed to ComputeSessions that will be used by that session
  // to create components.
  std::function<std::unique_ptr<Component>(const string &component_name,
                                           const string &backend_type)>
-      component_builder_;
+      component_builder_ GUARDED_BY(lock_);

  // ComputeSessions that are not currently being used. These sessions are not
  // reset until they are requested by another thread.
-  std::vector<std::unique_ptr<ComputeSession>> sessions_;
+  std::vector<std::unique_ptr<ComputeSession>> sessions_ GUARDED_BY(lock_);

  // Count of the number of unique ComputeSession objects that have been
  // created. Used to assign IDs to new Sessions.
-  int num_unique_sessions_;
+  int num_unique_sessions_ GUARDED_BY(lock_);

  // Mutex that protects accesses to all members of this object.
  tensorflow::mutex lock_;
research/syntaxnet/dragnn/core/index_translator.cc

@@ -33,7 +33,7 @@ IndexTranslator::IndexTranslator(const std::vector<Component *> &path,
  } else if (method_ == "history") {
    // History lookup: Return the number of steps taken less the feature.
    step_lookup_ = [this](int batch_index, int beam_index, int feature) {
-      if (feature > path_.back()->StepsTaken(batch_index) - 1) {
+      if (feature > path_.back()->StepsTaken(batch_index) - 1 || feature < 0) {
        VLOG(2) << "Translation to outside: feature is " << feature
                << " and steps_taken is "
                << path_.back()->StepsTaken(batch_index);
research/syntaxnet/dragnn/core/interfaces/BUILD

@@ -16,8 +16,9 @@ cc_library(
        ":transition_state",
        "//dragnn/components/util:bulk_feature_extractor",
        "//dragnn/core:input_batch_cache",
-        "//dragnn/protos:spec_proto",
-        "//dragnn/protos:trace_proto",
+        "//dragnn/core/util:label",
+        "//dragnn/protos:spec_proto_cc",
+        "//dragnn/protos:trace_proto_cc",
        "//syntaxnet:base",
        "//syntaxnet:registry",
    ],
research/syntaxnet/dragnn/core/interfaces/component.h

@@ -21,6 +21,7 @@
 #include "dragnn/components/util/bulk_feature_extractor.h"
 #include "dragnn/core/input_batch_cache.h"
 #include "dragnn/core/interfaces/transition_state.h"
+#include "dragnn/core/util/label.h"
 #include "dragnn/protos/spec.pb.h"
 #include "dragnn/protos/trace.pb.h"
 #include "syntaxnet/registry.h"

@@ -120,6 +121,18 @@ class Component : public RegisterableClass<Component> {
      const vector<const float *> &per_channel_embeddings,
      float *embedding_output) = 0;

+  // Directly computes the embedding matrix for all channels, advancing the
+  // component via the oracle until it is terminal. This call takes a vector
+  // of float matrices containing embeddings, one per channel, in channel
+  // order. This function outputs a densified right-ragged tensor.
+  virtual void BulkEmbedDenseFixedFeatures(
+      const vector<const float *> &per_channel_embeddings,
+      float *embedding_output, int embedding_output_size,
+      int32 *offset_array_output, int offset_array_size) = 0;
+
+  // Gets the expected size of the data matrix for BulkEmbedDenseFixedFeatures.
+  virtual int BulkDenseFeatureSize() const = 0;
+
  // Extracts and returns the vector of LinkFeatures for the specified
  // channel. Note: these are NOT translated.
  virtual std::vector<LinkFeatures> GetRawLinkFeatures(

@@ -127,7 +140,8 @@ class Component : public RegisterableClass<Component> {
  // Returns a vector of oracle labels for each element in the beam and
  // batch.
-  virtual std::vector<std::vector<int>> GetOracleLabels() const = 0;
+  virtual std::vector<std::vector<std::vector<Label>>> GetOracleLabels()
+      const = 0;

  // Annotate the underlying data object with the results of this Component's
  // calculation.
research/syntaxnet/dragnn/core/interfaces/transition_state.h

@@ -29,8 +29,8 @@ namespace dragnn {
 // another, and every backend should define one. Note that inheriting from
 // TransitionState directly is not sufficient to use the Beam class, which
 // requires extra functionality given by inheriting from the
-// ClonableTransitionState interface. (ClonableTransitionState is a subclass
-// of TransitionState, so inheriting from ClonableTransitionState is sufficient
+// CloneableTransitionState interface. (CloneableTransitionState is a subclass
+// of TransitionState, so inheriting from CloneableTransitionState is sufficient
 // to allow Components to pass your backing states.)
 class TransitionState {
research/syntaxnet/dragnn/core/ops/compute_session_op.cc

@@ -62,6 +62,10 @@ void ComputeSessionOp::Compute(OpKernelContext *context) {
                    "Must declare at least one output of type string "
                    "for the ComputeSession handle if OutputsHandle is true."));
  }
+  OP_REQUIRES(
+      context, context->input(0).dims() == 1,
+      InvalidArgument("Input to ComputeSession must be a vector, got rank ",
+                      context->input(0).dims()));

  // Gets the relevant ComputeSessionResource and computes with it.
  auto handle = context->input(0).vec<string>();
research/syntaxnet/dragnn/core/ops/dragnn_bulk_op_kernels.cc

@@ -20,6 +20,7 @@
 #include "dragnn/core/ops/compute_session_op.h"
 #include "dragnn/core/resource_container.h"
+#include "dragnn/core/util/label.h"
 #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
 #include "tensorflow/core/framework/numeric_types.h"
 #include "tensorflow/core/framework/op_kernel.h"

@@ -40,10 +41,10 @@ using tensorflow::DataType;
 using tensorflow::OpKernel;
 using tensorflow::OpKernelConstruction;
 using tensorflow::OpKernelContext;
-using tensorflow::quint8;
 using tensorflow::Status;
 using tensorflow::Tensor;
 using tensorflow::TensorShape;
+using tensorflow::quint8;
 using tensorflow::uint8;

 namespace syntaxnet {

@@ -335,11 +336,19 @@ class BulkEmbedFixedFeatures : public ComputeSessionOp {
      embeddings[channel] =
          context->input(embeddings_index).flat<float>().data();
    }
+    int batch_size;
+    if (pad_to_batch_ == -1) {
+      batch_size = session->BatchSize(component_name());
+    } else {
+      batch_size = pad_to_batch_;
+    }
+    VLOG(2) << "batch size: " << batch_size;
    Tensor *embedding_vectors;
    OP_REQUIRES_OK(context,
                   context->allocate_output(
-                       1, TensorShape({pad_to_steps_ * pad_to_batch_ *
+                       1, TensorShape({pad_to_steps_ * batch_size *
                                            session->BeamSize(component_name()),
                                        embedding_size}),
                       &embedding_vectors));

@@ -348,8 +357,8 @@ class BulkEmbedFixedFeatures : public ComputeSessionOp {
                                          &num_steps_tensor));
    embedding_vectors->flat<float>().setZero();
    int output_size = embedding_vectors->NumElements();
-    session->BulkEmbedFixedFeatures(component_name(), pad_to_batch_,
-                                    pad_to_steps_, output_size, embeddings,
+    session->BulkEmbedFixedFeatures(component_name(), batch_size,
+                                    pad_to_steps_, output_size, embeddings,
                                    embedding_vectors->flat<float>().data());
    num_steps_tensor->scalar<int32>()() = pad_to_steps_;
  }

@@ -370,6 +379,74 @@ class BulkEmbedFixedFeatures : public ComputeSessionOp {
 REGISTER_KERNEL_BUILDER(Name("BulkEmbedFixedFeatures").Device(DEVICE_CPU),
                        BulkEmbedFixedFeatures);

+// See docstring in dragnn_bulk_ops.cc.
+class BulkEmbedDenseFixedFeatures : public ComputeSessionOp {
+ public:
+  explicit BulkEmbedDenseFixedFeatures(OpKernelConstruction *context)
+      : ComputeSessionOp(context) {
+    OP_REQUIRES_OK(context, context->GetAttr("num_channels", &num_channels_));
+    // The input vector's zeroth element is the state handle, and the remaining
+    // num_channels_ elements are tensors of float embeddings, one per channel.
+    std::vector<DataType> input_types(num_channels_ + 1, DT_FLOAT);
+    input_types[0] = DT_STRING;
+    const std::vector<DataType> output_types = {DT_STRING, DT_FLOAT, DT_INT32};
+    OP_REQUIRES_OK(context, context->MatchSignature(input_types, output_types));
+  }
+
+  bool OutputsHandle() const override { return true; }
+  bool RequiresComponentName() const override { return true; }
+
+  void ComputeWithState(OpKernelContext *context,
+                        ComputeSession *session) override {
+    const auto &spec = session->Spec(component_name());
+    int embedding_size = 0;
+    std::vector<const float *> embeddings(num_channels_);
+    for (int channel = 0; channel < num_channels_; ++channel) {
+      const int embeddings_index = channel + 1;
+      embedding_size += context->input(embeddings_index).shape().dim_size(1) *
+                        spec.fixed_feature(channel).size();
+      embeddings[channel] =
+          context->input(embeddings_index).flat<float>().data();
+    }
+    auto component = session->GetReadiedComponent(component_name());
+    int data_tensor_size = component->BulkDenseFeatureSize();
+    Tensor *embedding_vectors;
+    OP_REQUIRES_OK(context,
+                   context->allocate_output(
+                       1, TensorShape({data_tensor_size, embedding_size}),
+                       &embedding_vectors));
+    Tensor *offset_array_tensor;
+    OP_REQUIRES(context, component->BeamSize() == 1,
+                tensorflow::errors::FailedPrecondition("Beam must be 1."));
+    OP_REQUIRES_OK(context,
+                   context->allocate_output(
+                       2, TensorShape({component->BatchSize() + 1}),
+                       &offset_array_tensor));
+    embedding_vectors->flat<float>().setZero();
+    int output_size = embedding_vectors->NumElements();
+    int offset_array_size = offset_array_tensor->NumElements();
+    component->BulkEmbedDenseFixedFeatures(
+        embeddings, embedding_vectors->flat<float>().data(), output_size,
+        offset_array_tensor->flat<int32>().data(), offset_array_size);
+  }
+
+ private:
+  // Number of fixed feature channels.
+  int num_channels_;
+
+  // Will pad output to this many batch elements.
+  int pad_to_batch_;
+
+  // Will pad output to this many steps.
+  int pad_to_steps_;
+
+  TF_DISALLOW_COPY_AND_ASSIGN(BulkEmbedDenseFixedFeatures);
+};
+
+REGISTER_KERNEL_BUILDER(Name("BulkEmbedDenseFixedFeatures").Device(DEVICE_CPU),
+                        BulkEmbedDenseFixedFeatures);
+
 // See docstring in dragnn_bulk_ops.cc.
 class BulkAdvanceFromOracle : public ComputeSessionOp {
  public:

@@ -388,7 +465,9 @@ class BulkAdvanceFromOracle : public ComputeSessionOp {
    const int batch_size = session->BatchSize(component_name());
    const int beam_size = session->BeamSize(component_name());
    const int num_items = batch_size * beam_size;
-    vector<vector<vector<int32>>> gold;
+
+    // Nested vector of size step_count * batch_size * beam_size * label_count.
+    vector<vector<vector<vector<Label>>>> gold;
    int num_steps = 0;
    while (!session->IsTerminal(component_name())) {

@@ -408,8 +487,12 @@ class BulkAdvanceFromOracle : public ComputeSessionOp {
    for (int batch_ix = 0; batch_ix < batch_size; ++batch_ix) {
      for (int beam_ix = 0; beam_ix < beam_size; ++beam_ix, ++item) {
        for (int step = 0; step < num_steps; ++step) {
+          // The default transition system behavior is a one-hot multi-class
+          // prediction, so there is only one gold label. If there are more
+          // than one gold labels, the code assumes they are equally valid,
+          // and we arbitrarily pick the first one.
          gold_output->vec<int32>()(item * num_steps + step) =
-              step < gold.size() ? gold[step][batch_ix][beam_ix] : -1;
+              step < gold.size() ? gold[step][batch_ix][beam_ix][0].id : -1;
        }
      }
    }
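Editor's note: the BulkAdvanceFromOracle kernel above flattens a nested gold structure of shape [step][batch][beam][labels] into an int32 output of shape [batch * beam, num_steps]. A standalone sketch of that indexing, not taken from the commit, with an assumed minimal Label record:

  #include <vector>

  struct Label { int id; float probability; };  // assumed fields, as used in this diff

  std::vector<int> FlattenGold(
      const std::vector<std::vector<std::vector<std::vector<Label>>>> &gold,
      int batch_size, int beam_size, int num_steps) {
    // Entry (item, step) lands at item * num_steps + step; missing steps are
    // padded with -1 and the first gold label is taken, as in the kernel.
    std::vector<int> out(batch_size * beam_size * num_steps, -1);
    int item = 0;
    for (int batch_ix = 0; batch_ix < batch_size; ++batch_ix) {
      for (int beam_ix = 0; beam_ix < beam_size; ++beam_ix, ++item) {
        for (int step = 0; step < num_steps; ++step) {
          if (step < static_cast<int>(gold.size())) {
            out[item * num_steps + step] = gold[step][batch_ix][beam_ix][0].id;
          }
        }
      }
    }
    return out;
  }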
research/syntaxnet/dragnn/core/ops/dragnn_bulk_op_kernels_test.cc

@@ -17,6 +17,7 @@
 #include "dragnn/core/compute_session_pool.h"
 #include "dragnn/core/resource_container.h"
 #include "dragnn/core/test/mock_compute_session.h"
+#include "dragnn/core/util/label.h"
 #include <gmock/gmock.h>
 #include "tensorflow/core/framework/fake_input.h"

@@ -624,13 +625,16 @@ TEST_F(DragnnBulkOpKernelsTest, BulkAdvanceFromOracle) {
      .WillOnce(Return(true));
  EXPECT_CALL(*mock_session, AdvanceFromOracle(kComponentName))
      .Times(kNumSteps);
-  const vector<vector<vector<int32>>> gold = {
-      {{1}, {1}, {1}},
-      {{2}, {2}, {2}},
-      {{3}, {3}, {3}},
+  const std::vector<std::vector<std::vector<std::vector<Label>>>> gold_labels{
+      {{{{1, 1.f}}}, {{{1, 1.f}}}, {{{1, 1.f}}}},
+      {{{{2, 1.f}}}, {{{2, 1.f}}}, {{{2, 1.f}}}},
+      {{{{3, 1.f}}}, {{{3, 1.f}}}, {{{3, 1.f}}}},
  };
  EXPECT_CALL(*mock_session, EmitOracleLabels(kComponentName))
-      .WillOnce(Return(gold[0]))
-      .WillOnce(Return(gold[1]))
-      .WillOnce(Return(gold[2]));
+      .WillOnce(Return(gold_labels[0]))
+      .WillOnce(Return(gold_labels[1]))
+      .WillOnce(Return(gold_labels[2]));
  EXPECT_CALL(*mock_session, BeamSize(kComponentName)).WillOnce(Return(1));
  EXPECT_CALL(*mock_session, BatchSize(kComponentName))
      .WillOnce(Return(kNumItems));
research/syntaxnet/dragnn/core/ops/dragnn_bulk_ops.cc

@@ -13,6 +13,7 @@
 // limitations under the License.
 // =============================================================================

+#include "dragnn/core/ops/shape_helpers.h"
 #include "tensorflow/core/framework/op.h"
 #include "tensorflow/core/framework/shape_inference.h"

@@ -28,6 +29,15 @@ REGISTER_OP("BulkFixedFeatures")
    .Output("num_steps: int32")
    .Attr("component: string")
    .Attr("num_channels: int")
+    .SetShapeFn([](tensorflow::shape_inference::InferenceContext *context) {
+      int num_channels;
+      TF_RETURN_IF_ERROR(context->GetAttr("num_channels", &num_channels));
+      for (int i = 1; i <= 3 * num_channels; ++i) {
+        VectorOutputShape(i, context);
+      }
+      ScalarOutputShape(3 * num_channels + 1, context);
+      return ComputeSessionHandleInputAndOutputShape(context);
+    })
    .Doc(R"doc(
Given a ComputeSession and a component, outputs fixed features for all steps.

@@ -60,6 +70,16 @@ REGISTER_OP("BulkFixedEmbeddings")
    .Attr("pad_to_batch: int=-1")
    .Attr("pad_to_steps: int=-1")
    .SetIsStateful()
+    .SetShapeFn([](tensorflow::shape_inference::InferenceContext *context) {
+      int num_channels;
+      TF_RETURN_IF_ERROR(context->GetAttr("num_channels", &num_channels));
+      for (int i = 1; i <= num_channels; ++i) {
+        TF_RETURN_IF_ERROR(MatrixInputShape(i, context));
+      }
+      MatrixOutputShape(1, context);
+      ScalarOutputShape(2, context);
+      return ComputeSessionHandleInputAndOutputShape(context);
+    })
    .Doc(R"doc(
This op is a more efficient version of BulkFixedFeatures.

@@ -91,6 +111,16 @@ REGISTER_OP("BulkEmbedFixedFeatures")
    .Attr("pad_to_batch: int")
    .Attr("pad_to_steps: int")
    .SetIsStateful()
+    .SetShapeFn([](tensorflow::shape_inference::InferenceContext *context) {
+      int num_channels;
+      TF_RETURN_IF_ERROR(context->GetAttr("num_channels", &num_channels));
+      for (int i = 1; i <= num_channels; ++i) {
+        TF_RETURN_IF_ERROR(MatrixInputShape(i, context));
+      }
+      MatrixOutputShape(1, context);
+      ScalarOutputShape(2, context);
+      return ComputeSessionHandleInputAndOutputShape(context);
+    })
    .Doc(R"doc(
This op is a more efficient version of BulkFixedFeatures.

@@ -112,11 +142,55 @@ pad_to_batch: The op will pad/truncate to this number of elements.
pad_to_batch: The op will pad/truncate to this number of elements.
pad_to_steps: The op will pad/truncate to this number of steps.
)doc");

+REGISTER_OP("BulkEmbedDenseFixedFeatures")
+    .Input("handle: string")
+    .Input("embedding_matrix: num_channels * float")
+    .Output("output_handle: string")
+    .Output("embedding_vectors: float")
+    .Output("offset_array: int32")
+    .Attr("component: string")
+    .Attr("num_channels: int")
+    .SetIsStateful()
+    .SetShapeFn([](tensorflow::shape_inference::InferenceContext *context) {
+      int num_channels;
+      TF_RETURN_IF_ERROR(context->GetAttr("num_channels", &num_channels));
+      for (int i = 1; i <= num_channels; ++i) {
+        TF_RETURN_IF_ERROR(MatrixInputShape(i, context));
+      }
+      MatrixOutputShape(1, context);
+      VectorOutputShape(2, context);
+      return ComputeSessionHandleInputAndOutputShape(context);
+    })
+    .Doc(R"doc(
+This op is a more efficient version of BulkFixedFeatures.
+
+It is intended to be run with large batch sizes at inference time. The op takes
+a handle to ComputeSession and embedding matrices as tensor inputs, and directly
+outputs concatenated embedding vectors. It calls the BulkEmbedFixedFeatures
+method on the underlying component directly, so it requires a padding vector
+to be passed.
+
+handle: A handle to ComputeSession.
+embedding_matrix: Embedding matrices.
+output_handle: A handle to the same ComputeSession after advancement.
+embedding_vectors: (matrix of float) Concatenated embeddings, in a dense
+    array.
+offset_array: An array of integers representing the offset of each batch element
+    in the embedding_vectors array. It is of size (batch+1) and the last element
+    is the total size of the embedding array.
+component: The name of a Component instance, matching the ComponentSpec.name.
+num_channels: The number of FixedFeature channels.
+)doc");
+
 REGISTER_OP("BulkAdvanceFromOracle")
    .Input("handle: string")
    .Output("output_handle: string")
    .Output("gold_labels: int32")
    .Attr("component: string")
+    .SetShapeFn([](tensorflow::shape_inference::InferenceContext *context) {
+      VectorOutputShape(1, context);
+      return ComputeSessionHandleInputAndOutputShape(context);
+    })
    .Doc(R"doc(
Given a ComputeSession, advances until all states are final.

@@ -140,14 +214,9 @@ REGISTER_OP("BulkAdvanceFromPrediction")
    .Output("output_handle: string")
    .Attr("component: string")
    .Attr("T: type")
-    .SetShapeFn([](tensorflow::shape_inference::InferenceContext *c) {
-      tensorflow::shape_inference::ShapeHandle handle;
-      TF_RETURN_IF_ERROR(c->Merge(c->input(0), c->Vector(2), &handle));
-      c->set_output(0, handle);
-      auto scores = c->input(1);
-      TF_RETURN_IF_ERROR(c->WithRank(scores, 2, &scores));
-      return tensorflow::Status::OK();
+    .SetShapeFn([](tensorflow::shape_inference::InferenceContext *context) {
+      TF_RETURN_IF_ERROR(MatrixInputShape(1, context));
+      return ComputeSessionHandleInputAndOutputShape(context);
    })
    .Doc(R"doc(
Given a ComputeSession and a tensor of scores, advances the state.
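Editor's note: the offset_array contract in the BulkEmbedDenseFixedFeatures docstring above (size batch+1, last element equal to the total number of rows) is the usual ragged-row layout. A minimal, illustrative sketch of how a consumer might recover each batch element's row range from that array; the function name is hypothetical:

  #include <cstdint>
  #include <utility>
  #include <vector>

  // Given an offset array of size batch+1, rows [offsets[i], offsets[i+1]) of
  // the dense [total_rows, embedding_size] output belong to batch element i;
  // the final offset equals the total row count.
  std::vector<std::pair<int32_t, int32_t>> RowRangesPerBatchElement(
      const std::vector<int32_t> &offsets) {
    std::vector<std::pair<int32_t, int32_t>> ranges;
    for (size_t i = 0; i + 1 < offsets.size(); ++i) {
      ranges.emplace_back(offsets[i], offsets[i + 1]);
    }
    return ranges;
  }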
research/syntaxnet/dragnn/core/ops/dragnn_op_kernels.cc

@@ -21,6 +21,7 @@
 #include "dragnn/core/compute_session_pool.h"
 #include "dragnn/core/ops/compute_session_op.h"
 #include "dragnn/core/resource_container.h"
+#include "dragnn/core/util/label.h"
 #include "dragnn/protos/data.pb.h"
 #include "dragnn/protos/spec.pb.h"
 #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"

@@ -41,8 +42,6 @@ using tensorflow::DT_INT32;
 using tensorflow::DT_INT64;
 using tensorflow::DT_STRING;
 using tensorflow::DataType;
-using tensorflow::io::Dirname;
-using tensorflow::io::JoinPath;
 using tensorflow::OpKernel;
 using tensorflow::OpKernelConstruction;
 using tensorflow::OpKernelContext;

@@ -50,6 +49,8 @@ using tensorflow::ResourceMgr;
 using tensorflow::Status;
 using tensorflow::Tensor;
 using tensorflow::TensorShape;
+using tensorflow::io::Dirname;
+using tensorflow::io::JoinPath;

 namespace syntaxnet {
 namespace dragnn {

@@ -330,6 +331,209 @@ class GetSessionCounts : public OpKernel {
 REGISTER_KERNEL_BUILDER(Name("GetSessionCounts").Device(DEVICE_CPU),
                        GetSessionCounts);

+// Rebatches a dense ragged tensor into a batch of padded subsequences.
+class RebatchDensor : public OpKernel {
+ public:
+  explicit RebatchDensor(OpKernelConstruction *context) : OpKernel(context) {
+    OP_REQUIRES_OK(context,
+                   context->GetAttr("sequence_length", &sequence_length_));
+    OP_REQUIRES_OK(context, context->GetAttr("lr_padding", &lr_padding_));
+    OP_REQUIRES_OK(context, context->MatchSignature({DT_FLOAT, DT_INT32},
+                                                    {DT_FLOAT, DT_INT32}));
+    OP_REQUIRES(context, lr_padding_ < sequence_length_,
+                tensorflow::errors::FailedPrecondition(
+                    "Sequence length must be longer than padding."));
+  }
+
+  void Compute(OpKernelContext *context) override {
+    // Figure out how many sequences we need.
+    const Tensor &data = context->input(0);
+    const int embedding_size = data.shape().dim_size(1);
+    const Tensor &offsets = context->input(1);
+    const int offsets_size = offsets.shape().dim_size(0);
+    const int batch_size = offsets_size - 1;
+    const auto &offset_data = offsets.vec<int32>();
+    int num_elements = 0;
+    for (int i = 0; i < batch_size; ++i) {
+      int element_length = offset_data(i + 1) - offset_data(i);
+      if (element_length > 0) {
+        int num_full_sequences = element_length / sequence_length_;
+        int length = ((element_length % sequence_length_) == 0)
+                         ? (num_full_sequences)
+                         : (num_full_sequences + 1);
+        num_elements += length;
+        VLOG(2) << "Item " << i << " of length " << element_length
+                << " will use " << length << ". Total: " << num_elements;
+      }
+    }
+
+    int output_sequence_length = 2 * lr_padding_ + sequence_length_;
+    VLOG(2) << "Rebatch shape: " << num_elements << " "
+            << output_sequence_length << " " << embedding_size;
+
+    // Allocate the output tensors.
+    Tensor *output;
+    OP_REQUIRES_OK(context,
+                   context->allocate_output(
+                       0, TensorShape({num_elements, output_sequence_length,
+                                       embedding_size}),
+                       &output));
+    output->flat<float>().setZero();
+    Tensor *indices;
+    OP_REQUIRES_OK(context, context->allocate_output(
+                                1, TensorShape({num_elements}), &indices));
+
+    const float *dense_data = data.flat<float>().data();
+    float *output_data = output->flat<float>().data();
+    int64 start_offset = lr_padding_ * embedding_size;
+    int64 seq_max_length = lr_padding_ + sequence_length_;
+    int64 row_index = 0;
+    for (int i = 0; i < batch_size; ++i) {
+      int64 element_length = offset_data(i + 1) - offset_data(i);
+      VLOG(2) << "Rebatching index " << i << " with size " << element_length;
+      if (element_length == 0) {
+        continue;
+      }
+      int64 first_seq_length = std::min(element_length, seq_max_length);
+      int64 subseqence_length = first_seq_length * embedding_size;
+      int64 dense_start = offset_data(i) * embedding_size;
+      int64 output_start =
+          row_index * output_sequence_length * embedding_size + start_offset;
+      for (int j = 0; j < subseqence_length; ++j) {
+        output_data[output_start + j] = dense_data[dense_start + j];
+      }
+      indices->vec<int32>()(row_index) = i;
+      VLOG(2) << "Rebatched " << i << " to " << row_index;
+      ++row_index;
+      int64 tokens_remaining = element_length - sequence_length_;
+      VLOG(2) << "Remaining: " << tokens_remaining;
+      while (tokens_remaining > 0) {
+        int64 seq_length = std::min(tokens_remaining, seq_max_length);
+        int64 subseqence_length = (seq_length + lr_padding_) * embedding_size;
+        int64 data_start =
+            (offset_data(i + 1) - tokens_remaining) - lr_padding_;
+        int64 dense_start = data_start * embedding_size;
+        int64 output_start =
+            row_index * output_sequence_length * embedding_size;
+        for (int j = 0; j < subseqence_length; ++j) {
+          output_data[output_start + j] = dense_data[dense_start + j];
+        }
+        indices->vec<int32>()(row_index) = i;
+        VLOG(2) << "Rebatched " << i << " to " << row_index;
+        ++row_index;
+        tokens_remaining -= sequence_length_;
+        VLOG(2) << "Remaining: " << tokens_remaining;
+      }
+    }
+    for (int j = 0; j < num_elements; ++j) {
+      VLOG(2) << "Rebatch item :" << j
+              << " has index: " << indices->vec<int32>()(j);
+    }
+  }
+
+ private:
+  int sequence_length_;
+  int lr_padding_;
+
+  TF_DISALLOW_COPY_AND_ASSIGN(RebatchDensor);
+};
+
+REGISTER_KERNEL_BUILDER(Name("RebatchDensor").Device(DEVICE_CPU),
+                        RebatchDensor);
+
+// Rebatches a dense ragged tensor into a batch of padded subsequences.
+class UnbatchSubsequences : public OpKernel {
+ public:
+  explicit UnbatchSubsequences(OpKernelConstruction *context)
+      : OpKernel(context) {
+    OP_REQUIRES_OK(context, context->MatchSignature(
+                                {DT_FLOAT, DT_INT32, DT_INT32}, {DT_FLOAT}));
+  }
+
+  void Compute(OpKernelContext *context) override {
+    // Figure out how many sequences we need.
+    const Tensor &data = context->input(0);
+    const int input_batch_size = data.shape().dim_size(0);
+    const int sequence_length = data.shape().dim_size(2);
+    const int embedding_size = data.shape().dim_size(3);
+    const int input_size = data.NumElements();
+    const Tensor &indices = context->input(1);
+    const int indices_size = indices.shape().dim_size(0);
+    const Tensor &offsets = context->input(2);
+    const int offsets_size = offsets.shape().dim_size(0);
+    const int batch_size = offsets_size - 1;
+    const auto &offset_data = offsets.vec<int32>();
+    int max_sequence_size = 0;
+    for (int i = 0; i < batch_size; ++i) {
+      int element_length = offset_data(i + 1) - offset_data(i);
+      if (element_length > max_sequence_size) {
+        max_sequence_size = element_length;
+      }
+    }
+
+    // Allocate the output tensors.
+    Tensor *output;
+    VLOG(2) << "Unbatch shape: " << batch_size << " " << max_sequence_size
+            << " " << embedding_size;
+    OP_REQUIRES_OK(context,
+                   context->allocate_output(
+                       0, TensorShape({batch_size, max_sequence_size,
+                                       embedding_size}),
+                       &output));
+    output->flat<float>().setZero();
+    int output_size = output->NumElements();
+
+    const float *input_data = data.flat<float>().data();
+    float *output_data = output->flat<float>().data();
+    const int32 *index_data = indices.flat<int32>().data();
+    int previous_index = -1;
+    int current_sequence_element = 0;
+    VLOG(2) << "Sequence length: " << sequence_length;
+    VLOG(2) << "Indices size: " << indices_size;
+    for (int i = 0; i < indices_size; ++i) {
+      int current_index = index_data[i];
+      CHECK(current_index < input_batch_size) << "Index out of bounds.";
+      if (current_index > previous_index) {
+        previous_index = current_index;
+        current_sequence_element = 0;
+      }
+      int current_sequence_length = std::min(
+          sequence_length, max_sequence_size - current_sequence_element);
+      int64 input_offset = i * sequence_length * embedding_size;
+      int64 output_offset =
+          (current_index * max_sequence_size + current_sequence_element) *
+          embedding_size;
+      VLOG(2) << "cur_ind: " << current_index
+              << " cur_element: " << current_sequence_element
+              << " cur sqlen: " << current_sequence_length
+              << " in: " << input_offset << " out: " << output_offset;
+      for (int j = 0; j < current_sequence_length * embedding_size; ++j) {
+        CHECK((output_offset + j) < output_size) << "output index invalid";
+        CHECK((input_offset + j) < input_size) << "input index invalid";
+        output_data[output_offset + j] = input_data[input_offset + j];
+      }
+      current_sequence_element += current_sequence_length;
+    }
+  }
+
+ private:
+  TF_DISALLOW_COPY_AND_ASSIGN(UnbatchSubsequences);
+};
+
+REGISTER_KERNEL_BUILDER(Name("UnbatchSubsequences").Device(DEVICE_CPU),
+                        UnbatchSubsequences);
+
 /*******************************************************************************
 * ComputeSessionOps below here.
 ******************************************************************************/

@@ -450,8 +654,8 @@ class ExtractFixedFeatures : public ComputeSessionOp {
        component_name(), indices_allocator, ids_allocator, weights_allocator,
        channel_id_);
    VLOG(2) << "Extracted features (" << num_features << "): "
            << " ids=" << context->mutable_output(1)->vec<int64>()
            << " weights=" << context->mutable_output(2)->vec<float>()
            << " indices=" << context->mutable_output(0)->vec<int32>();
  }

@@ -546,7 +750,8 @@ REGISTER_KERNEL_BUILDER(Name("ExtractLinkFeatures").Device(DEVICE_CPU),
 // Given a handle to a BatchedBeamComponentState, emits a vector of gold
 // labels.
-// The vector of gold labels has size batch_size * beam_size.
+// The vector of gold labels has size batch_size * beam_size. The code assumes
+// one label per instance.
 class EmitOracleLabels : public ComputeSessionOp {
  public:
  explicit EmitOracleLabels(OpKernelConstruction *context)

@@ -567,12 +772,13 @@ class EmitOracleLabels : public ComputeSessionOp {
                       TensorShape({session->BatchSize(component_name()) *
                                    session->BeamSize(component_name())}),
                       &output));
-    std::vector<std::vector<int>> batched_labels =
+    std::vector<std::vector<std::vector<Label>>> batched_labels =
        session->EmitOracleLabels(component_name());
    int raw_index = 0;
    for (const auto &batch_vector : batched_labels) {
-      for (const auto &label : batch_vector) {
-        output->vec<int32>()(raw_index) = label;
+      for (const auto &instance_labels : batch_vector) {
+        // The code assumes there is one label per instance.
+        output->vec<int32>()(raw_index) = instance_labels.at(0).id;
        ++raw_index;
      }
    }

@@ -585,6 +791,66 @@ class EmitOracleLabels : public ComputeSessionOp {
 REGISTER_KERNEL_BUILDER(Name("EmitOracleLabels").Device(DEVICE_CPU),
                        EmitOracleLabels);

+// Given a handle to a BatchedBeamComponentState, emits corresponding vectors
+// of indices, gold labels, and probabilities. The size of the output vectors
+// is equal to the sum of the number of labels for each instance in the beams
+// in the batch.
+class EmitOracleLabelsAndProbabilities : public ComputeSessionOp {
+ public:
+  explicit EmitOracleLabelsAndProbabilities(OpKernelConstruction *context)
+      : ComputeSessionOp(context) {
+    OP_REQUIRES_OK(context, context->MatchSignature(
+                                {DT_STRING}, {DT_INT32, DT_INT32, DT_FLOAT}));
+  }
+
+  bool OutputsHandle() const override { return false; }
+  bool RequiresComponentName() const override { return true; }
+
+  void ComputeWithState(OpKernelContext *context,
+                        ComputeSession *session) override {
+    const std::vector<std::vector<std::vector<Label>>> batched_labels =
+        session->EmitOracleLabels(component_name());
+    int label_count = 0;
+    for (const auto &beam : batched_labels) {
+      for (const auto &instance : beam) {
+        label_count += instance.size();
+      }
+    }
+    Tensor *indices_output;
+    OP_REQUIRES_OK(context,
+                   context->allocate_output(0, TensorShape({label_count}),
+                                            &indices_output));
+    Tensor *label_output;
+    OP_REQUIRES_OK(context,
+                   context->allocate_output(1, TensorShape({label_count}),
+                                            &label_output));
+    Tensor *prob_output;
+    OP_REQUIRES_OK(context,
+                   context->allocate_output(2, TensorShape({label_count}),
+                                            &prob_output));
+
+    // Index keeping track of each instance in the beams in the batch.
+    int instance_index = -1;
+    int raw_index = -1;
+    for (const auto &beam : batched_labels) {
+      for (const auto &instance : beam) {
+        ++instance_index;
+        for (const Label &label : instance) {
+          ++raw_index;
+          indices_output->vec<int32>()(raw_index) = instance_index;
+          label_output->vec<int32>()(raw_index) = label.id;
+          prob_output->vec<float>()(raw_index) = label.probability;
+        }
+      }
+    }
+  }
+
+ private:
+  TF_DISALLOW_COPY_AND_ASSIGN(EmitOracleLabelsAndProbabilities);
+};
+
+REGISTER_KERNEL_BUILDER(
+    Name("EmitOracleLabelsAndProbabilities").Device(DEVICE_CPU),
+    EmitOracleLabelsAndProbabilities);
+
 // Given a handle to a ComponentState, emits a single bool indicating
 // whether all elements in the batch contain beams containing all final states.
 class EmitAllFinal : public ComputeSessionOp {
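Editor's note: the RebatchDensor kernel above splits each ragged batch element into fixed-size windows of sequence_length new tokens plus lr_padding tokens of context on each side (missing context is left as zero padding). The sketch below, which is illustrative rather than part of the commit, computes the token ranges each output window draws from, under that reading of the kernel:

  #include <algorithm>
  #include <utility>
  #include <vector>

  // Returns [start, end) token ranges within one element of `element_length`
  // tokens; each range is copied into the middle of an output row of width
  // 2 * lr_padding + sequence_length, and positions outside the range stay 0.
  std::vector<std::pair<int, int>> WindowTokenRanges(int element_length,
                                                     int sequence_length,
                                                     int lr_padding) {
    std::vector<std::pair<int, int>> ranges;
    for (int start = 0; start < element_length; start += sequence_length) {
      int window_start = std::max(0, start - lr_padding);
      int window_end =
          std::min(element_length, start + sequence_length + lr_padding);
      ranges.emplace_back(window_start, window_end);
    }
    return ranges;
  }

With sequence_length = 3 and lr_padding = 2, a 4-token element yields windows [0, 4) and [1, 4), matching the expectations in the RebatchDensor test below.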
research/syntaxnet/dragnn/core/ops/dragnn_op_kernels_test.cc
View file @
80178fc6
...
...
@@ -23,6 +23,7 @@
#include "dragnn/core/resource_container.h"
#include "dragnn/core/test/generic.h"
#include "dragnn/core/test/mock_compute_session.h"
#include "dragnn/core/util/label.h"
#include <gmock/gmock.h>
...
...
@@ -44,26 +45,26 @@ namespace syntaxnet {
namespace
dragnn
{
using
tensorflow
::
AllocatorAttributes
;
using
tensorflow
::
checkpoint
::
TensorSliceReaderCacheWrapper
;
using
tensorflow
::
DT_BOOL
;
using
tensorflow
::
DT_FLOAT
;
using
tensorflow
::
DT_STRING
;
using
tensorflow
::
DT_INT32
;
using
tensorflow
::
FrameAndIter
;
using
tensorflow
::
DT_STRING
;
using
tensorflow
::
DataType
;
using
tensorflow
::
FrameAndIter
;
using
tensorflow
::
NodeDefBuilder
;
using
tensorflow
::
OpKernelContext
;
using
tensorflow
::
ResourceMgr
;
using
tensorflow
::
ScopedStepContainer
;
using
tensorflow
::
Status
;
using
tensorflow
::
test
::
SetOutputAttrs
;
using
tensorflow
::
TensorShape
;
using
tensorflow
::
checkpoint
::
TensorSliceReaderCacheWrapper
;
using
tensorflow
::
test
::
SetOutputAttrs
;
using
testing
::
_
;
using
testing
::
ElementsAreArray
;
using
testing
::
Invoke
;
using
testing
::
Pointwise
;
using
testing
::
Return
;
using
testing
::
_
;
typedef
ResourceContainer
<
ComputeSession
>
ComputeSessionResource
;
typedef
ResourceContainer
<
ComputeSessionPool
>
ComputeSessionPoolResource
;
...
...
@@ -126,12 +127,18 @@ class TestComponent : public Component {
int
batch_size_padding
,
int
num_steps_padding
,
int
output_array_size
,
const
vector
<
const
float
*>
&
per_channel_embeddings
,
float
*
embedding_matrix
)
override
{}
void
BulkEmbedDenseFixedFeatures
(
const
vector
<
const
float
*>
&
per_channel_embeddings
,
float
*
embedding_output
,
int
embedding_output_size
,
int
*
offset_array_output
,
int
offset_array_size
)
override
{}
int
BulkDenseFeatureSize
()
const
override
{
return
0
;
}
std
::
vector
<
LinkFeatures
>
GetRawLinkFeatures
(
int
channel_id
)
const
override
{
std
::
vector
<
LinkFeatures
>
ret
;
return
ret
;
}
std
::
vector
<
std
::
vector
<
int
>>
GetOracleLabels
()
const
override
{
std
::
vector
<
std
::
vector
<
int
>>
ret
;
std
::
vector
<
std
::
vector
<
std
::
vector
<
Label
>>>
GetOracleLabels
()
const
override
{
std
::
vector
<
std
::
vector
<
std
::
vector
<
Label
>>>
ret
;
return
ret
;
}
void
FinalizeData
()
override
{}
...
...
@@ -482,6 +489,201 @@ TEST_F(DragnnOpKernelsTest, GetSessionCountsOpTest) {
            GetOutput(0)->vec<int64>()(1));
}

// The RebatchDensor op should rebatch densors.
TEST_F(DragnnOpKernelsTest, RebatchDensorOpTest) {
  int sequence_length = 3;
  int pad_length = 2;
  TF_ASSERT_OK(NodeDefBuilder("rebatch_densor", "RebatchDensor")
                   .Attr("sequence_length", sequence_length)
                   .Attr("lr_padding", pad_length)
                   .Input(FakeInput(DT_FLOAT))  // The dense data tensor.
                   .Input(FakeInput(DT_INT32))  // The offsets tensor.
                   .Finalize(node_def()));
  TF_ASSERT_OK(InitOp());

  // Set the input data.
  const std::vector<float> weights = {
      // PASSAGE 1
      1.01, 1.02,  //
      1.04, 1.05,  //
      1.07, 1.08,  //
      1.10, 1.11,  //
      // PASSAGE 2
      2.01, 2.02,  //
      2.03, 2.04,  //
      2.05, 2.06,  //
      2.07, 2.08,  //
      2.09, 2.10,  //
      2.11, 2.12   //
  };
  AddInputFromArray<float>(TensorShape({10, 2}), weights);
  const std::vector<int> offsets = {0, 4, 10};
  AddInputFromArray<int>(TensorShape({3}), offsets);

  // Reset the test context to ensure it's clean.
  ResetOpKernelContext();

  // Run the kernel.
  TF_EXPECT_OK(RunOpKernelWithContext());

  // The first two embeddings in the 1st and 3rd output should be {0.0}.
  // The first two embeddings in the 2nd output should be embeddings from token
  // 1 and 2 (so vector items 4 through 10).
  // The last 2 embeddings in row 1 should be from token 4, then 0s.
  // The last 4 embeddings in rows 2 and 3 should be 0.
  const std::vector<float> expected_weights = {
      // BATCH 0
      0.0, 0.0,    //
      0.0, 0.0,    //
      1.01, 1.02,  //
      1.04, 1.05,  //
      1.07, 1.08,  //
      1.10, 1.11,  //
      0.0, 0.0,    //
      // BATCH 1
      1.04, 1.05,  //
      1.07, 1.08,  //
      1.10, 1.11,  //
      0.0, 0.0,    //
      0.0, 0.0,    //
      0.0, 0.0,    //
      0.0, 0.0,    //
      // BATCH 2
      0.0, 0.0,    //
      0.0, 0.0,    //
      2.01, 2.02,  //
      2.03, 2.04,  //
      2.05, 2.06,  //
      2.07, 2.08,  //
      2.09, 2.10,  //
      // BATCH 3
      2.03, 2.04,  //
      2.05, 2.06,  //
      2.07, 2.08,  //
      2.09, 2.10,  //
      2.11, 2.12,  //
      0.0, 0.0,    //
      0.0, 0.0,    //
  };
  for (int i = 0; i < expected_weights.size(); ++i) {
    LOG(INFO) << GetOutput(0)->flat<float>()(i);
  }

  // The output should have dimensions {4, 7, 2}.
  EXPECT_EQ(4, GetOutput(0)->dim_size(0));
  EXPECT_EQ(7, GetOutput(0)->dim_size(1));
  EXPECT_EQ(2, GetOutput(0)->dim_size(2));

  // The output should match the expected tensor.
  for (int i = 0; i < expected_weights.size(); ++i) {
    EXPECT_EQ(expected_weights[i], GetOutput(0)->flat<float>()(i))
        << "Failed at index " << i;
  }
  // The indices output should have dimension {4}.
  EXPECT_EQ(4, GetOutput(1)->dim_size(0));
  std::vector<int> expected_indices = {0, 0, 1, 1};
  for (int i = 0; i < expected_indices.size(); ++i) {
    EXPECT_EQ(expected_indices[i], GetOutput(1)->flat<int32>()(i))
        << "Failed at index " << i;
  }
}
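For readers checking the expected values above: each passage of length n contributes ceil(n / sequence_length) output rows, and each row is 2 * lr_padding + sequence_length tokens wide, zero-padded wherever the context runs off a passage boundary. A standalone sketch of that bookkeeping (illustrative only, not the op kernel itself):

// Illustrative sketch: computes how many output rows a RebatchDensor-style
// split produces and how wide each row is, given the test's parameters.
#include <iostream>
#include <vector>

int main() {
  const int sequence_length = 3;  // tokens per subsequence
  const int lr_padding = 2;       // context tokens on each side
  const std::vector<int> offsets = {0, 4, 10};  // passage boundaries

  int total_rows = 0;
  for (size_t p = 0; p + 1 < offsets.size(); ++p) {
    const int passage_length = offsets[p + 1] - offsets[p];
    total_rows += (passage_length + sequence_length - 1) / sequence_length;
  }
  const int row_width = 2 * lr_padding + sequence_length;
  std::cout << "rows=" << total_rows << " width=" << row_width << std::endl;
  // With the test inputs this prints rows=4 width=7, matching the expected
  // {4, 7, 2} output shape (2 being the embedding dimension).
}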
// The UnbatchSubsequences op should reassemble fixed-length subsequences into
// their original sequences.
TEST_F(DragnnOpKernelsTest, UnbatchSubsequences) {
  TF_ASSERT_OK(NodeDefBuilder("unbatch_subsequences", "UnbatchSubsequences")
                   .Input(FakeInput(DT_FLOAT))  // The data tensor.
                   .Input(FakeInput(DT_INT32))  // The index tensor.
                   .Input(FakeInput(DT_INT32))  // The offsets tensor.
                   .Finalize(node_def()));
  TF_ASSERT_OK(InitOp());

  // Set the input data.
  const std::vector<float> input = {
      // BATCH 0
      1.01, 1.02,  //
      1.04, 1.05,  //
      1.07, 1.08,  //
      1.10, 1.11,  //
      1.12, 1.13,  //
      // BATCH 1
      1.14, 1.15,  //
      0.0, 0.0,    //
      0.0, 0.0,    //
      0.0, 0.0,    //
      0.0, 0.0,    //
      // BATCH 2
      2.01, 2.02,  //
      0.0, 0.0,    //
      0.0, 0.0,    //
      0.0, 0.0,    //
      0.0, 0.0,    //
      // BATCH 3
      3.01, 3.02,  //
      0.0, 0.0,    //
      0.0, 0.0,    //
      0.0, 0.0,    //
      0.0, 0.0     //
  };
  AddInputFromArray<float>(TensorShape({4, 1, 5, 2}), input);
  const std::vector<int> indices = {0, 0, 1, 2};
  AddInputFromArray<int>(TensorShape({4}), indices);
  const std::vector<int> offsets = {0, 6, 7, 8};
  AddInputFromArray<int>(TensorShape({4}), offsets);

  // Reset the test context to ensure it's clean.
  ResetOpKernelContext();

  // Run the kernel.
  TF_EXPECT_OK(RunOpKernelWithContext());
  // Subsequences 0 and 1 both map back to original sequence 0, so its six
  // tokens are reassembled in order. Sequences 1 and 2 each contain a single
  // token followed by zero padding out to the longest unbatched length.
  const std::vector<float> expected_weights = {
      // BATCH 0
      1.01, 1.02,  //
      1.04, 1.05,  //
      1.07, 1.08,  //
      1.10, 1.11,  //
      1.12, 1.13,  //
      1.14, 1.15,  //
      // BATCH 1
      2.01, 2.02,  //
      0.0, 0.0,    //
      0.0, 0.0,    //
      0.0, 0.0,    //
      0.0, 0.0,    //
      0.0, 0.0,    //
      // BATCH 2
      3.01, 3.02,  //
      0.0, 0.0,    //
      0.0, 0.0,    //
      0.0, 0.0,    //
      0.0, 0.0,    //
      0.0, 0.0     //
  };
  for (int i = 0; i < expected_weights.size(); ++i) {
    LOG(INFO) << GetOutput(0)->flat<float>()(i);
  }
  // The output should have dimensions {3, 6, 2}.
  EXPECT_EQ(3, GetOutput(0)->dim_size(0));
  EXPECT_EQ(6, GetOutput(0)->dim_size(1));
  EXPECT_EQ(2, GetOutput(0)->dim_size(2));
  // The output should match the expected tensor.
  for (int i = 0; i < expected_weights.size(); ++i) {
    EXPECT_EQ(expected_weights[i], GetOutput(0)->flat<float>()(i))
        << "Failed at index " << i;
  }
}
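The index and offset bookkeeping in this test is worth spelling out: indices maps each input subsequence to its original sequence, and consecutive offsets give each original sequence's true token count (6, 1 and 1 here), which is why the output is zero-padded out to the longest length. A small sketch of that mapping, independent of the kernel:

// Illustrative only: derives per-sequence lengths and subsequence counts from
// the indices/offsets inputs used in the test above.
#include <algorithm>
#include <iostream>
#include <vector>

int main() {
  const std::vector<int> indices = {0, 0, 1, 2};  // subsequence -> sequence
  const std::vector<int> offsets = {0, 6, 7, 8};  // cumulative token counts

  const int num_sequences = offsets.size() - 1;
  std::vector<int> lengths(num_sequences), subsequences(num_sequences, 0);
  for (int s = 0; s < num_sequences; ++s) {
    lengths[s] = offsets[s + 1] - offsets[s];
  }
  for (int idx : indices) ++subsequences[idx];

  const int padded_length = *std::max_element(lengths.begin(), lengths.end());
  for (int s = 0; s < num_sequences; ++s) {
    std::cout << "sequence " << s << ": " << subsequences[s]
              << " subsequences, " << lengths[s] << " tokens (padded to "
              << padded_length << ")" << std::endl;
  }
  // With the test inputs this reports lengths {6, 1, 1} and a padded length
  // of 6, matching the expected {3, 6, 2} output shape.
}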
// The AdvanceFromOracle op should call AdvanceFromOracle on the specified
// component name.
TEST_F(DragnnOpKernelsTest, AdvanceFromOracleOpTest) {
...
...
@@ -651,7 +853,8 @@ TEST_F(DragnnOpKernelsTest, ExtractFixedFeaturesOpTest) {
// If we have 3 features, for a given channel, we might have:
// feature a: (5, 1)
// feature b: (5, 0.5), (6, 0.7)
// feature c: (3, 0.1), (7, [empty]) <- Empty weights are equivalent to 1.0.
// feature c: (3, 0.1), (7, [empty]) <- Empty weights are equivalent
// to 1.0.
// In this case:
// indices should look like [0 , 1 , 1 , 2 , 2 ]
// ids should be [5 , 5 , 6 , 3 , 7 ]
...
...
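As an aside on the indices/ids/weights layout described in the ExtractFixedFeaturesOpTest comment above: a consumer can fold the three parallel arrays into per-feature sums by accumulating weight times the embedding of the id into row indices[k]. The embedding values below are made up purely for illustration:

// Illustrative only: sums weighted embeddings from the parallel
// indices/ids/weights arrays described in the comment above.
#include <iostream>
#include <map>
#include <vector>

int main() {
  const std::vector<int> indices = {0, 1, 1, 2, 2};
  const std::vector<int> ids = {5, 5, 6, 3, 7};
  const std::vector<float> weights = {1.0, 0.5, 0.7, 0.1, 1.0};

  // Hypothetical 1-D embeddings keyed by feature id (illustration only).
  const std::map<int, float> embedding = {{3, 0.3}, {5, 0.5}, {6, 0.6}, {7, 0.7}};

  std::vector<float> feature_sum(3, 0.0);  // one slot per feature (a, b, c)
  for (size_t k = 0; k < indices.size(); ++k) {
    feature_sum[indices[k]] += weights[k] * embedding.at(ids[k]);
  }
  for (float v : feature_sum) std::cout << v << std::endl;
}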
@@ -727,15 +930,15 @@ TEST_F(DragnnOpKernelsTest, ExtractLinkFeaturesOpTest) {
  MockComputeSession *mock_session_ptr = mock_session.get();

  // This op will return link features in two flat arrays using batch-major
  // ordering. So, if we have a batch of 2 and a beam of 3, with data as follows
  // (note that the features are {batch,beam,step} and [] is 'empty')
  //   batch 1 features: {{02,03,[]},{01,00,04},{08,06,01}}
  //   batch 2 features: {{12,13,14},{11,12,-1},{18,16,20}}
  //
  // and a **source component** beam size of 5 should result in output tensors:
  //   step_idx (tensor 0):  {-1, 4, 1, 14, -1, 20}
  //   array_idx (tensor 1): { 0, 5, 46, 73, 0, 106}
  //   (0 [step=-1]),(5=1*5+0),(46=8*5+6),(73=12*5+13),(0 [step=-1]),(106=18*5+16)
  constexpr int kSourceComponentBeamSize = 5;
  std::vector<LinkFeatures> features;
...
...
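A quick worked check of the array_idx arithmetic described in the ExtractLinkFeaturesOpTest comment above. The struct and flattening loop below are only an illustrative sketch, not the op's implementation; empty features are modeled as step = -1:

// Illustrative sketch: reproduces the step_idx/array_idx flattening described
// in the comment above, using plain ints instead of the LinkFeatures proto.
#include <iostream>
#include <vector>

struct Feature {  // stand-in for one link feature: {batch, beam, step}
  int batch;
  int beam;
  int step;
};

int main() {
  const int source_beam_size = 5;
  const std::vector<Feature> features = {
      {2, 3, -1}, {1, 0, 4},   {8, 6, 1},
      {12, 13, 14}, {11, 12, -1}, {18, 16, 20}};

  for (const Feature &f : features) {
    const int step_idx = f.step;
    // Empty or negative steps map to array index 0; otherwise the flat index
    // interleaves batch and beam using the source component beam size.
    const int array_idx = (f.step < 0) ? 0 : f.batch * source_beam_size + f.beam;
    std::cout << step_idx << " -> " << array_idx << std::endl;
  }
  // Prints -1->0, 4->5, 1->46, 14->73, -1->0, 20->106, matching the expected
  // tensors in the comment.
}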
@@ -814,8 +1017,11 @@ TEST_F(DragnnOpKernelsTest, EmitOracleLabelsOpTest) {
  constexpr int kBatchSize = 2;
  constexpr int kBeamSize = 4;
  const std::vector<std::vector<int>> oracle_labels(
      {{1, 3, 5, 7}, {2, 4, 6, 8}});
  // Vectors containing, respectively, label ids and the corresponding Labels.
  const std::vector<std::vector<std::vector<Label>>> oracle_labels(
      {{{{1, 1.f}}, {{3, 1.f}}, {{5, 1.f}}, {{7, 1.f}}},
       {{{2, 1.f}}, {{4, 1.f}}, {{6, 1.f}}, {{8, 1.f}}}});
  EXPECT_CALL(*mock_session_ptr, BatchSize(component_name))
      .WillRepeatedly(Return(kBatchSize));
...
...
@@ -836,6 +1042,73 @@ TEST_F(DragnnOpKernelsTest, EmitOracleLabelsOpTest) {
  }
}

// The EmitOracleLabelsAndProbabilities op returns vectors of instance
// indices, labels, and probabilities corresponding to the elements in the
// beams in the batch.
TEST_F(DragnnOpKernelsTest, EmitOracleLabelsAndProbabilitiesOpTest) {
  // Create and initialize the kernel under test.
  const string component_name = "TESTING_COMPONENT_NAME";
  TF_ASSERT_OK(
      NodeDefBuilder("emit_oracle_labels_and_probabilities",
                     "EmitOracleLabelsAndProbabilities")
          .Attr("component", component_name)
          .Input(FakeInput(DT_STRING))  // The handle for the ComputeSession.
          .Finalize(node_def()));
  TF_ASSERT_OK(InitOp());

  // Set the input data.
  const string container_string = "container_str";
  const string id_string = "id_str";
  AddInputFromList<string>(TensorShape({2}), {container_string, id_string});

  // Reset the test context to ensure it's clean.
  ResetOpKernelContext();

  // Create a MockComputeSession and set expectations.
  std::unique_ptr<MockComputeSession> mock_session(new MockComputeSession());
  MockComputeSession *mock_session_ptr = mock_session.get();

  // Wrap the ComputeSessionResource and put it into the resource manager.
  TF_ASSERT_OK(resource_mgr()->Create<ComputeSessionResource>(
      container_string, id_string,
      new ComputeSessionResource(std::move(mock_session))));

  // The op should request the oracle labels, and probabilities. They should
  // be returned in batch major order, so if the label:probability pairs are:
  //   batch 1 oracle labels: {{1:0.6, 2:0.8}, {3:1.0}, {5:0.7}}
  //   batch 2 oracle labels: {{2:0.9}, {4:1.0}, {6:0.3, 8:0.6}}
  // then the resulting output tensors are:
  //   indices_output: {0, 0, 1, 2, 3, 4, 5, 5}
  //   label_output:   {1, 2, 3, 5, 2, 4, 6, 8}
  //   prob_output:    {0.6, 0.8, 1.0, 0.7, 0.9, 1.0, 0.3, 0.6}

  // Oracle labels along with their probabilities.
  const std::vector<std::vector<std::vector<Label>>> oracle_labels(
      {{{{1, 0.6}, {2, 0.8}}, {{3, 1.0}}, {{5, 0.7}}},
       {{{2, 0.9}}, {{4, 1.0}}, {{6, 0.3}, {8, 0.6}}}});
  EXPECT_CALL(*mock_session_ptr, EmitOracleLabels(component_name))
      .WillOnce(Return(oracle_labels));

  const std::vector<int> expected_indices({0, 0, 1, 2, 3, 4, 5, 5});
  const std::vector<int> expected_labels({1, 2, 3, 5, 2, 4, 6, 8});
  const std::vector<float> expected_probs(
      {0.6, 0.8, 1.0, 0.7, 0.9, 1.0, 0.3, 0.6});

  // Run the kernel.
  TF_EXPECT_OK(RunOpKernelWithContext());

  // Validate the outputs.
  EXPECT_EQ(expected_indices.size(), GetOutput(0)->NumElements());
  EXPECT_EQ(expected_labels.size(), GetOutput(1)->NumElements());
  EXPECT_EQ(expected_probs.size(), GetOutput(2)->NumElements());
  for (int i = 0; i < expected_indices.size(); ++i) {
    EXPECT_EQ(expected_indices[i], GetOutput(0)->vec<int32>()(i));
    EXPECT_EQ(expected_labels[i], GetOutput(1)->vec<int32>()(i));
    EXPECT_EQ(expected_probs[i], GetOutput(2)->vec<float>()(i));
  }
}
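For the expected vectors above, the flattening is batch-major: each beam element gets one instance index, and every Label in that element emits one (index, label, probability) triple. A standalone sketch of that walk (the Label struct here is a stand-in assumed to hold an id and a probability, not dragnn's own type):

// Illustrative flattening only; mirrors the expected_indices/labels/probs
// values used by the test above.
#include <iostream>
#include <vector>

struct Label {  // stand-in for syntaxnet::dragnn::Label
  int id;
  float probability;
};

int main() {
  const std::vector<std::vector<std::vector<Label>>> oracle_labels = {
      {{{1, 0.6}, {2, 0.8}}, {{3, 1.0}}, {{5, 0.7}}},
      {{{2, 0.9}}, {{4, 1.0}}, {{6, 0.3}, {8, 0.6}}}};

  std::vector<int> indices, labels;
  std::vector<float> probs;
  int instance = 0;
  for (const auto &batch : oracle_labels) {
    for (const auto &beam_element : batch) {
      for (const Label &label : beam_element) {
        indices.push_back(instance);
        labels.push_back(label.id);
        probs.push_back(label.probability);
      }
      ++instance;
    }
  }
  for (size_t i = 0; i < indices.size(); ++i) {
    std::cout << indices[i] << " " << labels[i] << " " << probs[i] << "\n";
  }
  // Prints the triples (0,1,0.6), (0,2,0.8), (1,3,1.0), (2,5,0.7),
  // (3,2,0.9), (4,4,1.0), (5,6,0.3), (5,8,0.6).
}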
// The EmitAllFinal op should return the result of IsTerminal(component_name).
TEST_F(DragnnOpKernelsTest, EmitAllFinalOpTest) {
  // Create and initialize the kernel under test.
...
...
research/syntaxnet/dragnn/core/ops/dragnn_ops.cc
View file @
80178fc6
...
...
@@ -13,7 +13,9 @@
// limitations under the License.
// =============================================================================
#include "dragnn/core/ops/shape_helpers.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/shape_inference.h"
namespace syntaxnet {
namespace dragnn {
...
...
@@ -22,6 +24,10 @@ REGISTER_OP("SetAssetDirectory")
    .Input("asset_directory: string")
    .Output("asset_directory_out: string")
    .SetIsStateful()
    .SetShapeFn([](tensorflow::shape_inference::InferenceContext *context) {
      context->set_output(0, context->Vector(1));
      return ScalarInputShape(0, context);
    })
    .Doc(R"doc(
Override the paths to assets specified in the MasterSpec with the given
asset_directory. This op must be called before any calls to GetSession, as it
...
...
@@ -38,6 +44,10 @@ REGISTER_OP("GetSession")
    .Attr("grid_point: string")
    .Output("handle: string")
    .SetIsStateful()
    .SetShapeFn([](tensorflow::shape_inference::InferenceContext *context) {
      TF_RETURN_IF_ERROR(ScalarInputShape(0, context));
      return ComputeSessionHandleOutputShape(context);
    })
    .Doc(R"doc(
Given MasterSpec and GridPoint protos, outputs a handle to a ComputeSession.
...
...
@@ -48,7 +58,11 @@ grid_point: A serialized syntaxnet.dragnn.GridPoint proto.
handle: A string handle to a ComputeSession.
)doc");

REGISTER_OP("ReleaseSession").Input("handle: string").SetIsStateful().Doc(R"doc(
REGISTER_OP("ReleaseSession")
    .Input("handle: string")
    .SetIsStateful()
    .SetShapeFn(ComputeSessionHandleInputShape)
    .Doc(R"doc(
Given a ComputeSession, return it to the ComputeSession pool.
This ComputeSession will no longer be available after this op returns.
...
...
@@ -60,6 +74,10 @@ REGISTER_OP("GetSessionCounts")
    .Input("container: string")
    .Output("stats: int64")
    .SetIsStateful()
    .SetShapeFn([](tensorflow::shape_inference::InferenceContext *context) {
      context->set_output(0, context->Vector(2));
      return ScalarInputShape(0, context);
    })
    .Doc(R"doc(
Given a container string, output session counts for that ComputeSessionPool.
...
...
@@ -68,11 +86,70 @@ stats: A vector of stats. [0] is the total number of created sessions. [1] is
the number of sessions that are currently not in the pool.
)doc");

REGISTER_OP("RebatchDensor")
    .Input("dense_data: float")
    .Input("offsets: int32")
    .Attr("sequence_length: int")
    .Attr("lr_padding: int")
    .Output("rebatched_data: float")
    .Output("rebatched_indices: int32")
    .SetIsStateful()
    .SetShapeFn([](tensorflow::shape_inference::InferenceContext *context) {
      int sequence_length;
      TF_RETURN_IF_ERROR(context->GetAttr("sequence_length", &sequence_length));
      int lr_padding;
      TF_RETURN_IF_ERROR(context->GetAttr("lr_padding", &lr_padding));
      const int output_sequence_length = 2 * lr_padding + sequence_length;
      TF_RETURN_IF_ERROR(MatrixInputShape(0, context));
      const auto embedding_dim = context->Dim(context->input(0), 1);
      context->set_output(
          0, context->MakeShape({context->UnknownDim(), output_sequence_length,
                                 embedding_dim}));
      VectorOutputShape(1, context);
      return VectorInputShape(1, context);
    })
    .Doc(R"doc(
Rebatch a dense ragged tensor into a set of fixed-size subsequences.
dense_data: A tensor containing the dense ragged data.
offsets: The passage offsets into the dense_data tensor.
sequence_length: The size of the sequence length to rebatch to.
lr_padding: The amount of context to pad when breaking a passage.
)doc");
REGISTER_OP("UnbatchSubsequences")
    .Input("data: float")
    .Input("indices: int32")
    .Input("offsets: int32")
    .Output("rebatched_data: float")
    .SetIsStateful()
    .SetShapeFn([](tensorflow::shape_inference::InferenceContext *context) {
      TF_RETURN_IF_ERROR(TensorInputShape(0, 4, context));
      const auto embedding_dim = context->Dim(context->input(0), 3);
      context->set_output(
          0, context->MakeShape({context->UnknownDim(), context->UnknownDim(),
                                 embedding_dim}));
      TF_RETURN_IF_ERROR(VectorInputShape(1, context));
      return VectorInputShape(2, context);
    })
    .Doc(R"doc(
Unbatch fixed-size subsequences back into their original sequences.
data: A tensor containing the fixed-length subsequences to unbatch.
indices: A tensor mapping the subsequences to the original sequences.
offsets: The passage offsets used to create the subsequences.
)doc");
REGISTER_OP("InitComponentData")
    .Input("handle: string")
    .Input("beam_size: int32")
    .Attr("component: string")
    .Output("output_handle: string")
    .SetShapeFn([](tensorflow::shape_inference::InferenceContext *context) {
      TF_RETURN_IF_ERROR(ScalarInputShape(1, context));
      return ComputeSessionHandleInputAndOutputShape(context);
    })
    .Doc(R"doc(
Initialize a component with the given beam size for a given ComputeSession.
...
...
@@ -86,6 +163,10 @@ REGISTER_OP("BatchSize")
    .Input("handle: string")
    .Attr("component: string")
    .Output("batch_size: int32")
    .SetShapeFn([](tensorflow::shape_inference::InferenceContext *context) {
      ScalarOutputShape(0, context);
      return ComputeSessionHandleInputShape(context);
    })
    .Doc(R"doc(
Given a ComputeSession and a component name, return the component batch size.
...
...
@@ -99,6 +180,10 @@ REGISTER_OP("SetTracing")
    .Input("tracing_on: bool")
    .Attr("component: string = 'NOT_USED_FOR_THIS_OP'")
    .Output("output_handle: string")
    .SetShapeFn([](tensorflow::shape_inference::InferenceContext *context) {
      TF_RETURN_IF_ERROR(ScalarInputShape(1, context));
      return ComputeSessionHandleInputAndOutputShape(context);
    })
    .Doc(R"doc(
Given a ComputeSession, turns on or off tracing for all components.
...
...
@@ -112,6 +197,10 @@ REGISTER_OP("AttachDataReader")
    .Input("input_spec: string")
    .Attr("component: string = 'NOT_USED_FOR_THIS_OP'")
    .Output("output_handle: string")
    .SetShapeFn([](tensorflow::shape_inference::InferenceContext *context) {
      TF_RETURN_IF_ERROR(VectorInputShape(1, context));
      return ComputeSessionHandleInputAndOutputShape(context);
    })
    .Doc(R"doc(
Given a ComputeSession, attach a data source.
...
...
@@ -127,6 +216,7 @@ REGISTER_OP("AdvanceFromOracle")
    .Input("handle: string")
    .Attr("component: string")
    .Output("output_handle: string")
    .SetShapeFn(ComputeSessionHandleInputAndOutputShape)
    .Doc(R"doc(
Given a ComputeSession and a Component name, advance the component via oracle.
...
...
@@ -140,6 +230,10 @@ REGISTER_OP("AdvanceFromPrediction")
    .Input("scores: float")
    .Attr("component: string")
    .Output("output_handle: string")
    .SetShapeFn([](tensorflow::shape_inference::InferenceContext *context) {
      TF_RETURN_IF_ERROR(MatrixInputShape(1, context));
      return ComputeSessionHandleInputAndOutputShape(context);
    })
    .Doc(R"doc(
Given a ComputeSession, a Component name, and a score tensor, advance the state.
...
...
@@ -156,6 +250,12 @@ REGISTER_OP("ExtractFixedFeatures")
    .Output("weights: float")
    .Attr("component: string")
    .Attr("channel_id: int")
    .SetShapeFn([](tensorflow::shape_inference::InferenceContext *context) {
      VectorOutputShape(0, context);
      VectorOutputShape(1, context);
      VectorOutputShape(2, context);
      return ComputeSessionHandleInputShape(context);
    })
    .Doc(R"doc(
Given a ComputeSession, Component, and channel index, output fixed features.
...
...
@@ -179,6 +279,11 @@ REGISTER_OP("ExtractLinkFeatures")
    .Output("idx: int32")
    .Attr("component: string")
    .Attr("channel_id: int")
    .SetShapeFn([](tensorflow::shape_inference::InferenceContext *context) {
      VectorOutputShape(0, context);
      VectorOutputShape(1, context);
      return ComputeSessionHandleInputShape(context);
    })
    .Doc(R"doc(
Given a ComputeSession, Component, and a channel index, outputs link features.
...
...
@@ -195,6 +300,10 @@ REGISTER_OP("EmitOracleLabels")
    .Input("handle: string")
    .Output("gold_labels: int32")
    .Attr("component: string")
    .SetShapeFn([](tensorflow::shape_inference::InferenceContext *context) {
      VectorOutputShape(0, context);
      return ComputeSessionHandleInputShape(context);
    })
    .Doc(R"doc(
Given a ComputeSession and Component, emit a vector of gold labels.
...
...
@@ -204,10 +313,39 @@ gold_labels: A [batch_size * beam_size] vector of gold labels for the current
component: The name of a Component instance, matching the ComponentSpec.name.
)doc");

REGISTER_OP("EmitOracleLabelsAndProbabilities")
    .Input("handle: string")
    .Output("instance_indices: int32")
    .Output("gold_labels: int32")
    .Output("probabilities: float")
    .Attr("component: string")
    .SetShapeFn([](tensorflow::shape_inference::InferenceContext *context) {
      VectorOutputShape(0, context);
      VectorOutputShape(1, context);
      VectorOutputShape(2, context);
      return ComputeSessionHandleInputShape(context);
    })
    .Doc(R"doc(
Given a ComputeSession and Component, emit corresponding vectors of instance
indices, gold labels, and probabilities.
handle: A handle to a ComputeSession.
instance_indices: A vector [N] of indices for the current ComputeSession, where
    N is the number of instance labels. Each element in each beam is
    assigned an index.
gold_labels: A vector [N] of gold labels for the current ComputeSession.
probabilities: A vector [N] of probabilities for the current ComputeSession.
component: The name of a Component instance, matching the ComponentSpec.name.
)doc");

REGISTER_OP("EmitAllFinal")
    .Input("handle: string")
    .Output("all_final: bool")
    .Attr("component: string")
    .SetShapeFn([](tensorflow::shape_inference::InferenceContext *context) {
      context->set_output(0, context->Vector(1));
      return ComputeSessionHandleInputShape(context);
    })
    .Doc(R"doc(
Given a ComputeSession and Component, returns whether the Component is final.
...
...
@@ -223,6 +361,7 @@ REGISTER_OP("WriteAnnotations")
    .Input("handle: string")
    .Output("output_handle: string")
    .Attr("component: string")
    .SetShapeFn(ComputeSessionHandleInputAndOutputShape)
    .Doc(R"doc(
Given a ComputeSession, has the given component write out its annotations.
...
...
@@ -238,6 +377,10 @@ REGISTER_OP("EmitAnnotations")
    .Input("handle: string")
    .Output("annotations: string")
    .Attr("component: string")
    .SetShapeFn([](tensorflow::shape_inference::InferenceContext *context) {
      VectorOutputShape(0, context);
      return ComputeSessionHandleInputShape(context);
    })
    .Doc(R"doc(
Given a ComputeSession, emits strings with final predictions for the model.
...
...
@@ -252,6 +395,10 @@ REGISTER_OP("GetComponentTrace")
    .Input("handle: string")
    .Output("trace: string")
    .Attr("component: string")
    .SetShapeFn([](tensorflow::shape_inference::InferenceContext *context) {
      VectorOutputShape(0, context);
      return ComputeSessionHandleInputShape(context);
    })
    .Doc(R"doc(
Gets the raw MasterTrace proto for each batch, state, and beam slot.
...
...
research/syntaxnet/dragnn/core/ops/shape_helpers.h
0 → 100644
View file @
80178fc6
// Copyright 2018 Google Inc. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
// Shape inference functions for DRAGNN ops.
#ifndef DRAGNN_CORE_OPS_SHAPE_HELPERS_H_
#define DRAGNN_CORE_OPS_SHAPE_HELPERS_H_
#include "syntaxnet/ops/shape_helpers.h"
#include "tensorflow/core/framework/shape_inference.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status.h"
namespace syntaxnet {
namespace dragnn {

// Returns OK if the 0'th input of the |context| is compatible with the shape of
// a ComputeSession handle.
inline tensorflow::Status ComputeSessionHandleInputShape(
    tensorflow::shape_inference::InferenceContext *context) {
  tensorflow::shape_inference::ShapeHandle unused;
  return context->Merge(context->input(0), context->Vector(2), &unused);
}

// Sets the 0'th output of the |context| to have the shape of a ComputeSession
// handle. Always returns OK.
inline tensorflow::Status ComputeSessionHandleOutputShape(
    tensorflow::shape_inference::InferenceContext *context) {
  context->set_output(0, context->Vector(2));
  return tensorflow::Status::OK();
}

// For convenience, combines ComputeSessionHandle{Input,Output}Shape().
inline tensorflow::Status ComputeSessionHandleInputAndOutputShape(
    tensorflow::shape_inference::InferenceContext *context) {
  TF_RETURN_IF_ERROR(ComputeSessionHandleInputShape(context));
  return ComputeSessionHandleOutputShape(context);
}

}  // namespace dragnn
}  // namespace syntaxnet

#endif  // DRAGNN_CORE_OPS_SHAPE_HELPERS_H_
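For context, these helpers are meant to be dropped into REGISTER_OP shape functions, exactly as the dragnn_ops.cc changes above do. A minimal, hypothetical registration using them (the op name and doc string are illustrative only, not part of this commit):

// Hypothetical registration, for illustration: the handle input/output helpers
// validate and produce the 2-element string-vector ComputeSession handle shape.
#include "dragnn/core/ops/shape_helpers.h"
#include "tensorflow/core/framework/op.h"

namespace syntaxnet {
namespace dragnn {

REGISTER_OP("ExampleSessionPassthrough")  // illustrative name
    .Input("handle: string")
    .Output("output_handle: string")
    .SetShapeFn(ComputeSessionHandleInputAndOutputShape)
    .Doc("Passes a ComputeSession handle through unchanged (example only).");

}  // namespace dragnn
}  // namespace syntaxnet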
research/syntaxnet/dragnn/core/test/BUILD
View file @
80178fc6
...
...
@@ -12,8 +12,9 @@ cc_library(
        "//dragnn/core:index_translator",
        "//dragnn/core/interfaces:component",
        "//dragnn/core/interfaces:transition_state",
        "//dragnn/protos:data_proto",
        "//dragnn/protos:spec_proto",
        "//dragnn/core/util:label",
        "//dragnn/protos:data_proto_cc",
        "//dragnn/protos:spec_proto_cc",
        "//syntaxnet:base",
        "//syntaxnet:test_main",
    ],
...
...
@@ -27,8 +28,9 @@ cc_library(
        "//dragnn/components/util:bulk_feature_extractor",
        "//dragnn/core:compute_session",
        "//dragnn/core:input_batch_cache",
        "//dragnn/protos:data_proto",
        "//dragnn/protos:spec_proto",
        "//dragnn/core/util:label",
        "//dragnn/protos:data_proto_cc",
        "//dragnn/protos:spec_proto_cc",
        "//syntaxnet:base",
        "//syntaxnet:test_main",
    ],
...
...
@@ -45,6 +47,12 @@ cc_library(
    ],
)

cc_library(
    name = "fake_component_base",
    hdrs = ["fake_component_base.h"],
    deps = ["//dragnn/core/interfaces:component"],
)

cc_library(
    name = "generic",
    testonly = True,
...
...
research/syntaxnet/dragnn/core/test/fake_component_base.h
0 → 100644
View file @
80178fc6
// Copyright 2017 Google Inc. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
#ifndef DRAGNN_CORE_TEST_FAKE_COMPONENT_BASE_H_
#define DRAGNN_CORE_TEST_FAKE_COMPONENT_BASE_H_
#include "dragnn/core/interfaces/component.h"
#include "dragnn/protos/data.pb.h"
namespace syntaxnet {
namespace dragnn {

// Define a test component to validate registered construction.
class FakeComponentBase : public Component {
 public:
  FakeComponentBase() {}

  void InitializeComponent(const ComponentSpec &spec) override {
    name_ = spec.name();
  }
  void InitializeData(
      const std::vector<std::vector<const TransitionState *>> &states,
      int max_beam_size, InputBatchCache *input_data) override {}
  void InitializeTracing() override {}
  void DisableTracing() override {}
  bool IsReady() const override { return true; }
  string Name() const override { return name_; }
  int BeamSize() const override { return 1; }
  int BatchSize() const override { return 1; }
  int StepsTaken(int batch_index) const override { return 0; }
  int GetBeamIndexAtStep(int step, int current_index,
                         int batch) const override {
    return 0;
  }
  int GetSourceBeamIndex(int current_index, int batch) const override {
    return 0;
  }
  bool AdvanceFromPrediction(const float *score_matrix, int num_items,
                             int num_actions) override {
    return true;
  }
  void AdvanceFromOracle() override {}
  bool IsTerminal() const override { return true; }
  std::function<int(int, int, int)> GetStepLookupFunction(
      const string &method) override {
    return nullptr;
  }
  std::vector<std::vector<const TransitionState *>> GetBeam() override {
    std::vector<std::vector<const TransitionState *>> states;
    return states;
  }
  int GetFixedFeatures(std::function<int32 *(int)> allocate_indices,
                       std::function<int64 *(int)> allocate_ids,
                       std::function<float *(int)> allocate_weights,
                       int channel_id) const override {
    return 0;
  }
  void BulkEmbedFixedFeatures(
      int batch_size_padding, int num_steps_padding, int embedding_size,
      const vector<const float *> &per_channel_embeddings,
      float *embedding_output) override {}
  void BulkEmbedDenseFixedFeatures(
      const vector<const float *> &per_channel_embeddings,
      float *embedding_output, int embedding_output_size,
      int *offset_array_output, int offset_array_size) override {}
  int BulkDenseFeatureSize() const override { return 0; }
  int BulkGetFixedFeatures(const BulkFeatureExtractor &extractor) override {
    return 0;
  }
  std::vector<LinkFeatures> GetRawLinkFeatures(int channel_id) const override {
    std::vector<LinkFeatures> ret;
    return ret;
  }
  std::vector<std::vector<std::vector<Label>>> GetOracleLabels()
      const override {
    std::vector<std::vector<std::vector<Label>>> ret;
    return ret;
  }
  void FinalizeData() override {}
  void ResetComponent() override {}
  std::vector<std::vector<ComponentTrace>> GetTraceProtos() const override {
    std::vector<std::vector<ComponentTrace>> ret;
    return ret;
  }
  void AddTranslatedLinkFeaturesToTrace(
      const std::vector<LinkFeatures> &features, int channel_id) override {}

  string name_;
};

}  // namespace dragnn
}  // namespace syntaxnet

#endif  // DRAGNN_CORE_TEST_FAKE_COMPONENT_BASE_H_
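A usage note: tests can subclass FakeComponentBase and override only the methods they care about, leaving the rest of the Component interface stubbed. A minimal, hypothetical example (the class name and return values are illustrative only):

// Hypothetical test helper built on FakeComponentBase; only BeamSize() and
// BatchSize() are overridden, all other Component methods keep their stubs.
#include "dragnn/core/test/fake_component_base.h"

namespace syntaxnet {
namespace dragnn {

class WideBeamFakeComponent : public FakeComponentBase {
 public:
  int BeamSize() const override { return 8; }
  int BatchSize() const override { return 4; }
};

}  // namespace dragnn
}  // namespace syntaxnet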
research/syntaxnet/dragnn/core/test/generic.h
View file @
80178fc6
...
...
@@ -27,7 +27,8 @@
namespace syntaxnet {
namespace test {

MATCHER_P(EqualsProto, a, "Protos are not equivalent:") {
MATCHER_P(EqualsProto, a,
          "Protos " + string(negation ? "aren't" : "are") + " equivalent:") {
  return a.DebugString() == arg.DebugString();
}
...
...
@@ -39,6 +40,16 @@ MATCHER_P(IsErrorWithSubstr, substr,
  return !arg.ok() && arg.error_message().find(substr) != string::npos;
}

// Matches an error status whose code and message match |code| and |substr|.
MATCHER_P2(IsErrorWithCodeAndSubstr, code, substr,
           string(negation ? "isn't" : "is") + " an error Status whose code is " +
               ::testing::PrintToString(code) +
               " and whose message matches the substring '" +
               ::testing::PrintToString(substr) + "'") {
  return !arg.ok() && arg.code() == code &&
         arg.error_message().find(substr) != string::npos;
}
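For illustration, here is how the new matcher reads at a call site; the test name, status value, and expected substring below are hypothetical:

// Hypothetical usage of the matcher above inside a googletest test body.
#include <gmock/gmock.h>
#include <gtest/gtest.h>
#include "dragnn/core/test/generic.h"
#include "tensorflow/core/lib/core/errors.h"

TEST(GenericMatchersExample, StatusCodeAndSubstring) {
  const tensorflow::Status status =
      tensorflow::errors::InvalidArgument("bad beam size");
  EXPECT_THAT(status, syntaxnet::test::IsErrorWithCodeAndSubstr(
                          tensorflow::error::INVALID_ARGUMENT, "beam size"));
}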
// Returns the prefix for where the test data is stored.
string GetTestDataPrefix();
...
...
research/syntaxnet/dragnn/core/test/mock_component.h
View file @
80178fc6
...
...
@@ -22,6 +22,7 @@
#include "dragnn/core/index_translator.h"
#include "dragnn/core/interfaces/component.h"
#include "dragnn/core/interfaces/transition_state.h"
#include "dragnn/core/util/label.h"
#include "dragnn/protos/data.pb.h"
#include "dragnn/protos/spec.pb.h"
#include "syntaxnet/base.h"
...
...
@@ -64,9 +65,15 @@ class MockComponent : public Component {
                   int output_array_size,
                   const vector<const float *> &per_channel_embeddings,
                   float *embedding_output));
  MOCK_METHOD5(BulkEmbedDenseFixedFeatures,
               void(const vector<const float *> &per_channel_embeddings,
                    float *embedding_output, int embedding_output_size,
                    int32 *offset_array_output, int offset_array_size));
  MOCK_CONST_METHOD0(BulkDenseFeatureSize, int());
  MOCK_CONST_METHOD1(GetRawLinkFeatures,
                     std::vector<LinkFeatures>(int channel_id));
  MOCK_CONST_METHOD0(GetOracleLabels, std::vector<std::vector<int>>());
  MOCK_CONST_METHOD0(GetOracleLabels,
                     std::vector<std::vector<std::vector<Label>>>());
  MOCK_METHOD0(ResetComponent, void());
  MOCK_METHOD1(GetStepLookupFunction,
               std::function<int(int, int, int)>(const string &method));
...
...