Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
ModelZoo
ResNet50_tensorflow
Commits
4364390a
Commit
4364390a
authored
Nov 13, 2017
by
Ivan Bogatyy
Committed by
calberti
Nov 13, 2017
Browse files
Release DRAGNN bulk networks (#2785)
* Release DRAGNN bulk networks
parent
638fd759
Changes
166
Expand all
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
1295 additions
and
261 deletions
+1295
-261
research/syntaxnet/dragnn/core/BUILD
research/syntaxnet/dragnn/core/BUILD
+5
-1
research/syntaxnet/dragnn/core/beam.h
research/syntaxnet/dragnn/core/beam.h
+197
-92
research/syntaxnet/dragnn/core/beam_test.cc
research/syntaxnet/dragnn/core/beam_test.cc
+551
-70
research/syntaxnet/dragnn/core/component_registry.h
research/syntaxnet/dragnn/core/component_registry.h
+8
-3
research/syntaxnet/dragnn/core/compute_session.h
research/syntaxnet/dragnn/core/compute_session.h
+27
-7
research/syntaxnet/dragnn/core/compute_session_impl.cc
research/syntaxnet/dragnn/core/compute_session_impl.cc
+21
-5
research/syntaxnet/dragnn/core/compute_session_impl.h
research/syntaxnet/dragnn/core/compute_session_impl.h
+18
-10
research/syntaxnet/dragnn/core/compute_session_impl_test.cc
research/syntaxnet/dragnn/core/compute_session_impl_test.cc
+73
-14
research/syntaxnet/dragnn/core/compute_session_pool.h
research/syntaxnet/dragnn/core/compute_session_pool.h
+11
-5
research/syntaxnet/dragnn/core/compute_session_pool_test.cc
research/syntaxnet/dragnn/core/compute_session_pool_test.cc
+1
-0
research/syntaxnet/dragnn/core/index_translator.h
research/syntaxnet/dragnn/core/index_translator.h
+3
-3
research/syntaxnet/dragnn/core/input_batch_cache.h
research/syntaxnet/dragnn/core/input_batch_cache.h
+38
-6
research/syntaxnet/dragnn/core/input_batch_cache_test.cc
research/syntaxnet/dragnn/core/input_batch_cache_test.cc
+52
-0
research/syntaxnet/dragnn/core/interfaces/cloneable_transition_state.h
...axnet/dragnn/core/interfaces/cloneable_transition_state.h
+20
-14
research/syntaxnet/dragnn/core/interfaces/component.h
research/syntaxnet/dragnn/core/interfaces/component.h
+17
-7
research/syntaxnet/dragnn/core/interfaces/input_batch.h
research/syntaxnet/dragnn/core/interfaces/input_batch.h
+8
-5
research/syntaxnet/dragnn/core/interfaces/transition_state.h
research/syntaxnet/dragnn/core/interfaces/transition_state.h
+18
-12
research/syntaxnet/dragnn/core/ops/compute_session_op.h
research/syntaxnet/dragnn/core/ops/compute_session_op.h
+3
-3
research/syntaxnet/dragnn/core/ops/dragnn_bulk_op_kernels.cc
research/syntaxnet/dragnn/core/ops/dragnn_bulk_op_kernels.cc
+72
-2
research/syntaxnet/dragnn/core/ops/dragnn_bulk_op_kernels_test.cc
.../syntaxnet/dragnn/core/ops/dragnn_bulk_op_kernels_test.cc
+152
-2
No files found.
research/syntaxnet/dragnn/core/BUILD
View file @
4364390a
...
...
@@ -33,8 +33,9 @@ cc_library(
name
=
"compute_session"
,
hdrs
=
[
"compute_session.h"
],
deps
=
[
":index_translator"
,
":input_batch_cache"
,
"//dragnn/components/util:bulk_feature_extractor"
,
"//dragnn/core:index_translator"
,
"//dragnn/core/interfaces:component"
,
"//dragnn/protos:spec_proto"
,
"//dragnn/protos:trace_proto"
,
...
...
@@ -120,8 +121,10 @@ cc_test(
":compute_session"
,
":compute_session_impl"
,
":compute_session_pool"
,
":input_batch_cache"
,
"//dragnn/components/util:bulk_feature_extractor"
,
"//dragnn/core/interfaces:component"
,
"//dragnn/core/interfaces:input_batch"
,
"//dragnn/core/test:generic"
,
"//dragnn/core/test:mock_component"
,
"//dragnn/core/test:mock_transition_state"
,
...
...
@@ -248,6 +251,7 @@ cc_library(
"//syntaxnet:base"
,
"@org_tensorflow//third_party/eigen3"
,
],
alwayslink
=
1
,
)
# Tensorflow kernel libraries, for use with unit tests.
...
...
research/syntaxnet/dragnn/core/beam.h
View file @
4364390a
...
...
@@ -13,8 +13,8 @@
// limitations under the License.
// =============================================================================
#ifndef
NLP_SAFT_OPENSOURCE_
DRAGNN_CORE_BEAM_H_
#define
NLP_SAFT_OPENSOURCE_
DRAGNN_CORE_BEAM_H_
#ifndef DRAGNN_CORE_BEAM_H_
#define DRAGNN_CORE_BEAM_H_
#include <algorithm>
#include <cmath>
...
...
@@ -43,19 +43,23 @@ class Beam {
static_assert
(
std
::
is_base_of
<
CloneableTransitionState
<
T
>
,
T
>::
value
,
"This class must be instantiated to use a CloneableTransitionState"
);
track_gold_
=
false
;
}
// Enables or disables gold tracking by storing |track_gold| in track_gold_.
// While the flag is false, ContainsGold() reports false regardless of the
// states held in the beam.
void
SetGoldTracking
(
bool
track_gold
)
{
track_gold_
=
track_gold
;
}
// Sets the Beam functions, as follows:
// bool is_allowed(TransitionState *, int): Return true if transition 'int' is
// allowed for transition state 'TransitionState *'.
// void perform_transition(TransitionState *, int): Performs transition 'int'
// on transition state 'TransitionState *'.
// int oracle_function(TransitionState *): Returns the oracle-
specified action
// for transition state 'TransitionState *'.
//
vector<
int
>
oracle_function(TransitionState *): Returns the oracle-
//
specified actions
for transition state 'TransitionState *'.
void
SetFunctions
(
std
::
function
<
bool
(
T
*
,
int
)
>
is_allowed
,
std
::
function
<
bool
(
T
*
)
>
is_final
,
std
::
function
<
void
(
T
*
,
int
)
>
perform_transition
,
std
::
function
<
int
(
T
*
)
>
oracle_function
)
{
std
::
function
<
vector
<
int
>
(
T
*
)
>
oracle_function
)
{
is_allowed_
=
is_allowed
;
is_final_
=
is_final
;
perform_transition_
=
perform_transition
;
...
...
@@ -74,12 +78,17 @@ class Beam {
for
(
int
i
=
0
;
i
<
beam_
.
size
();
++
i
)
{
previous_beam_indices
.
at
(
i
)
=
beam_
[
i
]
->
ParentBeamIndex
();
beam_
[
i
]
->
SetBeamIndex
(
i
);
// TODO(googleuser): Add gold tracking to component-level state creation.
if
(
!
track_gold_
)
{
beam_
[
i
]
->
SetGold
(
false
);
}
}
beam_index_history_
.
emplace_back
(
previous_beam_indices
);
}
// Advances the Beam from the given transition matrix.
void
AdvanceFromPrediction
(
const
float
transition_matrix
[]
,
int
matrix_length
,
bool
AdvanceFromPrediction
(
const
float
*
transition_matrix
,
int
matrix_length
,
int
num_actions
)
{
// Ensure that the transition matrix is the correct size. All underlying
// states should have the same transition profile, so using the one at 0
...
...
@@ -89,91 +98,20 @@ class Beam {
"state transitions!"
;
if
(
max_size_
==
1
)
{
// In the case where beam size is 1, we can advance by simply finding the
// highest score and advancing the beam state in place.
VLOG
(
2
)
<<
"Beam size is 1. Using fast beam path."
;
int
best_action
=
-
1
;
float
best_score
=
-
INFINITY
;
auto
&
state
=
beam_
[
0
];
for
(
int
action_idx
=
0
;
action_idx
<
num_actions
;
++
action_idx
)
{
if
(
is_allowed_
(
state
.
get
(),
action_idx
)
&&
transition_matrix
[
action_idx
]
>
best_score
)
{
best_score
=
transition_matrix
[
action_idx
];
best_action
=
action_idx
;
}
bool
success
=
FastAdvanceFromPrediction
(
transition_matrix
,
num_actions
);
if
(
!
success
)
{
return
false
;
}
CHECK_GE
(
best_action
,
0
)
<<
"Num actions: "
<<
num_actions
<<
" score[0]: "
<<
transition_matrix
[
0
];
perform_transition_
(
state
.
get
(),
best_action
);
const
float
new_score
=
state
->
GetScore
()
+
best_score
;
state
->
SetScore
(
new_score
);
state
->
SetBeamIndex
(
0
);
}
else
{
// Create the vector of all possible transitions, along with their scores.
std
::
vector
<
Transition
>
candidates
;
// Iterate through all beams, examining all actions for each beam.
for
(
int
beam_idx
=
0
;
beam_idx
<
beam_
.
size
();
++
beam_idx
)
{
const
auto
&
state
=
beam_
[
beam_idx
];
for
(
int
action_idx
=
0
;
action_idx
<
num_actions
;
++
action_idx
)
{
// If the action is allowed, calculate the proposed new score and add
// the candidate action to the vector of all actions at this state.
if
(
is_allowed_
(
state
.
get
(),
action_idx
))
{
Transition
candidate
;
// The matrix is laid out by beam index, with a linear set of
// actions for that index - so beam N's actions start at [nr. of
// actions]*[N].
const
int
matrix_idx
=
action_idx
+
beam_idx
*
num_actions
;
CHECK_LT
(
matrix_idx
,
matrix_length
)
<<
"Matrix index out of bounds!"
;
const
double
score_delta
=
transition_matrix
[
matrix_idx
];
CHECK
(
!
std
::
isnan
(
score_delta
));
candidate
.
source_idx
=
beam_idx
;
candidate
.
action
=
action_idx
;
candidate
.
resulting_score
=
state
->
GetScore
()
+
score_delta
;
candidates
.
emplace_back
(
candidate
);
}
}
}
// Sort the vector of all possible transitions and scores.
const
auto
comparator
=
[](
const
Transition
&
a
,
const
Transition
&
b
)
{
return
a
.
resulting_score
>
b
.
resulting_score
;
};
std
::
stable_sort
(
candidates
.
begin
(),
candidates
.
end
(),
comparator
);
// Apply the top transitions, up to a maximum of 'max_size_'.
std
::
vector
<
std
::
unique_ptr
<
T
>>
new_beam
;
std
::
vector
<
int
>
previous_beam_indices
(
max_size_
,
-
1
);
const
int
beam_size
=
std
::
min
(
max_size_
,
static_cast
<
int
>
(
candidates
.
size
()));
VLOG
(
2
)
<<
"Previous beam size = "
<<
beam_
.
size
();
VLOG
(
2
)
<<
"New beam size = "
<<
beam_size
;
VLOG
(
2
)
<<
"Maximum beam size = "
<<
max_size_
;
for
(
int
i
=
0
;
i
<
beam_size
;
++
i
)
{
// Get the source of the i'th transition.
const
auto
&
transition
=
candidates
[
i
];
VLOG
(
2
)
<<
"Taking transition with score: "
<<
transition
.
resulting_score
<<
" and action: "
<<
transition
.
action
;
VLOG
(
2
)
<<
"transition.source_idx = "
<<
transition
.
source_idx
;
const
auto
&
source
=
beam_
[
transition
.
source_idx
];
// Put the new transition on the new state beam.
auto
new_state
=
source
->
Clone
();
perform_transition_
(
new_state
.
get
(),
transition
.
action
);
new_state
->
SetScore
(
transition
.
resulting_score
);
new_state
->
SetBeamIndex
(
i
);
previous_beam_indices
.
at
(
i
)
=
transition
.
source_idx
;
new_beam
.
emplace_back
(
std
::
move
(
new_state
));
bool
success
=
BeamAdvanceFromPrediction
(
transition_matrix
,
matrix_length
,
num_actions
);
if
(
!
success
)
{
return
false
;
}
beam_
=
std
::
move
(
new_beam
);
beam_index_history_
.
emplace_back
(
previous_beam_indices
);
}
++
num_steps_
;
return
true
;
}
// Advances the Beam from the state oracles.
...
...
@@ -182,7 +120,10 @@ class Beam {
for
(
int
i
=
0
;
i
<
beam_
.
size
();
++
i
)
{
previous_beam_indices
.
at
(
i
)
=
i
;
if
(
is_final_
(
beam_
[
i
].
get
()))
continue
;
const
auto
oracle_label
=
oracle_function_
(
beam_
[
i
].
get
());
// There will always be at least one oracular transition, and taking the
// first returned transition is never worse than any other option.
const
int
oracle_label
=
oracle_function_
(
beam_
[
i
].
get
()).
at
(
0
);
VLOG
(
2
)
<<
"AdvanceFromOracle beam_index:"
<<
i
<<
" oracle_label:"
<<
oracle_label
;
perform_transition_
(
beam_
[
i
].
get
(),
oracle_label
);
...
...
@@ -312,19 +253,180 @@ class Beam {
// Returns the current size of the beam.
const
int
size
()
const
{
return
beam_
.
size
();
}
// Returns true if at least one of the states in the beam is gold.
bool
ContainsGold
()
{
if
(
!
track_gold_
)
{
return
false
;
}
for
(
const
auto
&
state
:
beam_
)
{
if
(
state
->
IsGold
())
{
return
true
;
}
}
return
false
;
}
private:
// Associates an action taken on an index into current_state_ with a score.
friend
void
BM_FastAdvance
(
int
num_iters
,
int
num_transitions
);
friend
void
BM_BeamAdvance
(
int
num_iters
,
int
num_transitions
,
int
max_beam_size
);
// Associates an action taken with its source index.
struct
Transition
{
// The index of the source item.
int
source_idx
;
// The index of the action being taken.
int
action
;
// The score of the full derivation.
double
resulting_score
;
};
// In the case where beam size is 1, we can advance by simply finding the
// highest score and advancing the beam state in place.
bool
FastAdvanceFromPrediction
(
const
float
*
transition_matrix
,
int
num_actions
)
{
CHECK_EQ
(
1
,
max_size_
)
<<
"Using fast advance on invalid beam. This should never happen."
;
VLOG
(
2
)
<<
"Beam size is 1. Using fast beam path."
;
constexpr
int
kNoActionChosen
=
-
1
;
int
best_action
=
kNoActionChosen
;
float
best_score
=
-
INFINITY
;
auto
&
state
=
beam_
[
0
];
for
(
int
action_idx
=
0
;
action_idx
<
num_actions
;
++
action_idx
)
{
if
(
std
::
isnan
(
transition_matrix
[
action_idx
]))
{
LOG
(
ERROR
)
<<
"Found a NaN in the transition matrix! Unable to "
"continue. Num actions: "
<<
num_actions
<<
" index: "
<<
action_idx
;
return
false
;
}
if
(
is_allowed_
(
state
.
get
(),
action_idx
)
&&
transition_matrix
[
action_idx
]
>
best_score
)
{
best_score
=
transition_matrix
[
action_idx
];
best_action
=
action_idx
;
}
}
if
(
best_action
==
kNoActionChosen
)
{
LOG
(
ERROR
)
<<
"No action was chosen! Unable to continue. Num actions: "
<<
num_actions
<<
" score[0]: "
<<
transition_matrix
[
0
];
return
false
;
}
bool
is_gold
=
false
;
if
(
track_gold_
&&
state
->
IsGold
())
{
for
(
const
auto
&
gold_transition
:
oracle_function_
(
state
.
get
()))
{
VLOG
(
3
)
<<
"Examining gold transition "
<<
gold_transition
<<
" for source index 1"
;
if
(
gold_transition
==
best_action
)
{
is_gold
=
true
;
break
;
}
}
}
perform_transition_
(
state
.
get
(),
best_action
);
const
float
new_score
=
state
->
GetScore
()
+
best_score
;
state
->
SetScore
(
new_score
);
state
->
SetBeamIndex
(
0
);
state
->
SetGold
(
is_gold
);
return
true
;
}
// In case the beam size is greater than 1, we need to advance using
// standard beam search.
bool
BeamAdvanceFromPrediction
(
const
float
*
transition_matrix
,
int
matrix_length
,
int
num_actions
)
{
VLOG
(
2
)
<<
"Beam size is "
<<
max_size_
<<
". Using standard beam search."
;
// Keep the multimap sorted high to low. The sort order for
// identical keys is stable.
std
::
multimap
<
float
,
Transition
,
std
::
greater
<
float
>>
candidates
;
float
threshold
=
-
INFINITY
;
// Iterate through all beams, examining all actions for each beam.
for
(
int
beam_idx
=
0
;
beam_idx
<
beam_
.
size
();
++
beam_idx
)
{
const
auto
&
state
=
beam_
[
beam_idx
];
const
float
score
=
state
->
GetScore
();
for
(
int
action_idx
=
0
;
action_idx
<
num_actions
;
++
action_idx
)
{
if
(
is_allowed_
(
state
.
get
(),
action_idx
))
{
// The matrix is laid out by beam index, with a linear set of
// actions for that index - so beam N's actions start at [nr. of
// actions]*[N].
const
int
matrix_idx
=
action_idx
+
beam_idx
*
num_actions
;
CHECK_LT
(
matrix_idx
,
matrix_length
)
<<
"Matrix index out of bounds!"
;
const
float
resulting_score
=
score
+
transition_matrix
[
matrix_idx
];
if
(
std
::
isnan
(
resulting_score
))
{
LOG
(
ERROR
)
<<
"Resulting score was a NaN! Unable to continue. Num "
"actions: "
<<
num_actions
<<
" action index "
<<
action_idx
;
return
false
;
}
if
(
candidates
.
size
()
==
max_size_
)
{
// If the new score is lower than the bottom of the beam, move on.
if
(
resulting_score
<
threshold
)
{
continue
;
}
// Otherwise, remove the bottom of the beam, making space
// for the new candidate.
candidates
.
erase
(
std
::
prev
(
candidates
.
end
()));
}
// Add the new candidate, and update the threshold score.
const
Transition
candidate
{
beam_idx
,
action_idx
};
candidates
.
emplace
(
resulting_score
,
candidate
);
threshold
=
candidates
.
rbegin
()
->
first
;
}
}
}
// Apply the top transitions, up to a maximum of 'max_size_'.
std
::
vector
<
std
::
unique_ptr
<
T
>>
new_beam
;
std
::
vector
<
int
>
previous_beam_indices
(
max_size_
,
-
1
);
const
int
beam_size
=
candidates
.
size
();
new_beam
.
reserve
(
max_size_
);
VLOG
(
2
)
<<
"Previous beam size = "
<<
beam_
.
size
();
VLOG
(
2
)
<<
"New beam size = "
<<
beam_size
;
VLOG
(
2
)
<<
"Maximum beam size = "
<<
max_size_
;
auto
candidate_iterator
=
candidates
.
cbegin
();
for
(
int
i
=
0
;
i
<
beam_size
;
++
i
)
{
// Get the score and source of the i'th transition.
const
float
resulting_score
=
candidate_iterator
->
first
;
const
auto
&
transition
=
candidate_iterator
->
second
;
++
candidate_iterator
;
VLOG
(
2
)
<<
"Taking transition with score: "
<<
resulting_score
<<
" and action: "
<<
transition
.
action
;
VLOG
(
2
)
<<
"transition.source_idx = "
<<
transition
.
source_idx
;
const
auto
&
source
=
beam_
[
transition
.
source_idx
];
// Determine if the transition being taken will result in a gold state.
bool
is_gold
=
false
;
if
(
track_gold_
&&
source
->
IsGold
())
{
for
(
const
auto
&
gold_transition
:
oracle_function_
(
source
.
get
()))
{
VLOG
(
3
)
<<
"Examining gold transition "
<<
gold_transition
<<
" for source index "
<<
transition
.
source_idx
;
if
(
gold_transition
==
transition
.
action
)
{
VLOG
(
2
)
<<
"State from index "
<<
transition
.
source_idx
<<
" is gold."
;
is_gold
=
true
;
break
;
}
}
}
VLOG
(
2
)
<<
"Gold examination complete for source index "
<<
transition
.
source_idx
;
// Put the new transition on the new state beam.
auto
new_state
=
source
->
Clone
();
perform_transition_
(
new_state
.
get
(),
transition
.
action
);
new_state
->
SetScore
(
resulting_score
);
new_state
->
SetBeamIndex
(
i
);
new_state
->
SetGold
(
is_gold
);
previous_beam_indices
.
at
(
i
)
=
transition
.
source_idx
;
new_beam
.
emplace_back
(
std
::
move
(
new_state
));
}
beam_
=
std
::
move
(
new_beam
);
beam_index_history_
.
emplace_back
(
previous_beam_indices
);
return
true
;
}
// The maximum beam size.
int
max_size_
;
...
...
@@ -341,7 +443,7 @@ class Beam {
std
::
function
<
void
(
T
*
,
int
)
>
perform_transition_
;
// Function to provide the oracle action for a given state.
std
::
function
<
int
(
T
*
)
>
oracle_function_
;
std
::
function
<
vector
<
int
>
(
T
*
)
>
oracle_function_
;
// The history of the states in this beam. The vector indexes across steps.
// For every step, there is a vector in the vector. This inner vector denotes
...
...
@@ -355,9 +457,12 @@ class Beam {
// The number of steps taken so far.
int
num_steps_
;
// Whether to track golden states.
bool
track_gold_
;
};
}
// namespace dragnn
}
// namespace syntaxnet
#endif //
NLP_SAFT_OPENSOURCE_
DRAGNN_CORE_BEAM_H_
#endif // DRAGNN_CORE_BEAM_H_
research/syntaxnet/dragnn/core/beam_test.cc
View file @
4364390a
This diff is collapsed.
Click to expand it.
research/syntaxnet/dragnn/core/component_registry.h
View file @
4364390a
...
...
@@ -13,12 +13,17 @@
// limitations under the License.
// =============================================================================
#ifndef
NLP_SAFT_OPENSOURCE_
DRAGNN_CORE_COMPONENT_REGISTRY_H_
#define
NLP_SAFT_OPENSOURCE_
DRAGNN_CORE_COMPONENT_REGISTRY_H_
#ifndef DRAGNN_CORE_COMPONENT_REGISTRY_H_
#define DRAGNN_CORE_COMPONENT_REGISTRY_H_
#include "dragnn/core/interfaces/component.h"
#include "syntaxnet/registry.h"
namespace
syntaxnet
{
// Class registry for DRAGNN components.
DECLARE_SYNTAXNET_CLASS_REGISTRY
(
"DRAGNN Component"
,
dragnn
::
Component
);
}
// namespace syntaxnet
// Macro to add a component to the registry. This macro associates a class with
// its class name as a string, so FooComponent would be associated with the
// string "FooComponent".
...
...
@@ -26,4 +31,4 @@
REGISTER_SYNTAXNET_CLASS_COMPONENT(syntaxnet::dragnn::Component, #component, \
component)
#endif //
NLP_SAFT_OPENSOURCE_
DRAGNN_CORE_COMPONENT_REGISTRY_H_
#endif // DRAGNN_CORE_COMPONENT_REGISTRY_H_
research/syntaxnet/dragnn/core/compute_session.h
View file @
4364390a
...
...
@@ -13,13 +13,14 @@
// limitations under the License.
// =============================================================================
#ifndef
NLP_SAFT_OPENSOURCE_
DRAGNN_CORE_COMPUTE_SESSION_H_
#define
NLP_SAFT_OPENSOURCE_
DRAGNN_CORE_COMPUTE_SESSION_H_
#ifndef DRAGNN_CORE_COMPUTE_SESSION_H_
#define DRAGNN_CORE_COMPUTE_SESSION_H_
#include <string>
#include "dragnn/components/util/bulk_feature_extractor.h"
#include "dragnn/core/index_translator.h"
#include "dragnn/core/input_batch_cache.h"
#include "dragnn/core/interfaces/component.h"
#include "dragnn/protos/spec.pb.h"
#include "dragnn/protos/trace.pb.h"
...
...
@@ -64,10 +65,11 @@ class ComputeSession {
// Advance the given component using the component's oracle.
virtual
void
AdvanceFromOracle
(
const
string
&
component_name
)
=
0
;
// Advance the given component using the given score matrix.
virtual
void
AdvanceFromPrediction
(
const
string
&
component_name
,
const
float
score_matrix
[],
int
score_matrix_length
)
=
0
;
// Advance the given component using the given score matrix, which is
// |num_items| x |num_actions|.
virtual
bool
AdvanceFromPrediction
(
const
string
&
component_name
,
const
float
*
score_matrix
,
int
num_items
,
int
num_actions
)
=
0
;
// Get the input features for the given component and channel. This passes
// through to the relevant Component's GetFixedFeatures() call.
...
...
@@ -84,6 +86,15 @@ class ComputeSession {
virtual
int
BulkGetInputFeatures
(
const
string
&
component_name
,
const
BulkFeatureExtractor
&
extractor
)
=
0
;
// Directly computes the embedding matrix for all channels, advancing the
// component via the oracle until it is terminal. This call takes a vector
// of float embedding matrices, one per channel, in channel order.
virtual
void
BulkEmbedFixedFeatures
(
const
string
&
component_name
,
int
batch_size_padding
,
int
num_steps_padding
,
int
output_array_size
,
const
vector
<
const
float
*>
&
per_channel_embeddings
,
float
*
embedding_output
)
=
0
;
// Get the input features for the given component and channel. This function
// can return empty LinkFeatures protos, which represent unused padding slots
// in the output weight tensor.
...
...
@@ -111,6 +122,10 @@ class ComputeSession {
// Provides the ComputeSession with a batch of data to compute.
virtual
void
SetInputData
(
const
std
::
vector
<
string
>
&
data
)
=
0
;
// Like SetInputData(), but accepts an InputBatchCache directly, potentially
// bypassing de-serialization.
virtual
void
SetInputBatchCache
(
std
::
unique_ptr
<
InputBatchCache
>
batch
)
=
0
;
// Resets all components owned by this ComputeSession.
virtual
void
ResetSession
()
=
0
;
...
...
@@ -127,9 +142,14 @@ class ComputeSession {
// validate correct construction of translators in tests.
virtual
const
std
::
vector
<
const
IndexTranslator
*>
Translators
(
const
string
&
component_name
)
const
=
0
;
// Get a given component. CHECK-fail if the component's IsReady method
// returns false.
virtual
Component
*
GetReadiedComponent
(
const
string
&
component_name
)
const
=
0
;
};
}
// namespace dragnn
}
// namespace syntaxnet
#endif //
NLP_SAFT_OPENSOURCE_
DRAGNN_CORE_COMPUTE_SESSION_H_
#endif // DRAGNN_CORE_COMPUTE_SESSION_H_
research/syntaxnet/dragnn/core/compute_session_impl.cc
View file @
4364390a
...
...
@@ -161,11 +161,11 @@ void ComputeSessionImpl::AdvanceFromOracle(const string &component_name) {
GetReadiedComponent
(
component_name
)
->
AdvanceFromOracle
();
}
void
ComputeSessionImpl
::
AdvanceFromPrediction
(
const
string
&
component_name
,
const
float
score_matrix
[]
,
int
score_matrix_length
)
{
GetReadiedComponent
(
component_name
)
->
AdvanceFromPrediction
(
score_matrix
,
score_matrix_length
);
bool
ComputeSessionImpl
::
AdvanceFromPrediction
(
const
string
&
component_name
,
const
float
*
score_matrix
,
int
num_items
,
int
num_actions
)
{
return
GetReadiedComponent
(
component_name
)
->
AdvanceFromPrediction
(
score_matrix
,
num_items
,
num_actions
);
}
int
ComputeSessionImpl
::
GetInputFeatures
(
...
...
@@ -182,6 +182,16 @@ int ComputeSessionImpl::BulkGetInputFeatures(
return
GetReadiedComponent
(
component_name
)
->
BulkGetFixedFeatures
(
extractor
);
}
// Forwards a bulk fixed-feature embedding request to the component named
// |component_name|, obtained through GetReadiedComponent(). All remaining
// arguments are handed to the component's BulkEmbedFixedFeatures() unchanged.
void ComputeSessionImpl::BulkEmbedFixedFeatures(
    const string &component_name, int batch_size_padding,
    int num_steps_padding, int output_array_size,
    const vector<const float *> &per_channel_embeddings,
    float *embedding_output) {
  auto *target = GetReadiedComponent(component_name);
  target->BulkEmbedFixedFeatures(batch_size_padding, num_steps_padding,
                                 output_array_size, per_channel_embeddings,
                                 embedding_output);
}
std
::
vector
<
LinkFeatures
>
ComputeSessionImpl
::
GetTranslatedLinkFeatures
(
const
string
&
component_name
,
int
channel_id
)
{
auto
*
component
=
GetReadiedComponent
(
component_name
);
...
...
@@ -288,6 +298,11 @@ void ComputeSessionImpl::SetInputData(const std::vector<string> &data) {
input_data_
.
reset
(
new
InputBatchCache
(
data
));
}
// Installs an already-constructed InputBatchCache as this session's input
// data by moving |batch| into input_data_; ownership passes to the session.
void
ComputeSessionImpl
::
SetInputBatchCache
(
std
::
unique_ptr
<
InputBatchCache
>
batch
)
{
input_data_
=
std
::
move
(
batch
);
}
void
ComputeSessionImpl
::
ResetSession
()
{
// Reset all component states.
for
(
auto
&
component_pair
:
components_
)
{
...
...
@@ -308,6 +323,7 @@ const std::vector<const IndexTranslator *> ComputeSessionImpl::Translators(
const
string
&
component_name
)
const
{
auto
translators
=
GetTranslators
(
component_name
);
std
::
vector
<
const
IndexTranslator
*>
const_translators
;
const_translators
.
reserve
(
translators
.
size
());
for
(
const
auto
&
translator
:
translators
)
{
const_translators
.
push_back
(
translator
);
}
...
...
research/syntaxnet/dragnn/core/compute_session_impl.h
View file @
4364390a
...
...
@@ -13,8 +13,8 @@
// limitations under the License.
// =============================================================================
#ifndef
NLP_SAFT_OPENSOURCE_
DRAGNN_CORE_COMPUTE_SESSION_IMPL_H_
#define
NLP_SAFT_OPENSOURCE_
DRAGNN_CORE_COMPUTE_SESSION_IMPL_H_
#ifndef DRAGNN_CORE_COMPUTE_SESSION_IMPL_H_
#define DRAGNN_CORE_COMPUTE_SESSION_IMPL_H_
#include <memory>
...
...
@@ -55,9 +55,9 @@ class ComputeSessionImpl : public ComputeSession {
void
AdvanceFromOracle
(
const
string
&
component_name
)
override
;
void
AdvanceFromPrediction
(
const
string
&
component_name
,
const
float
score_matrix
[]
,
int
score_matrix_length
)
override
;
bool
AdvanceFromPrediction
(
const
string
&
component_name
,
const
float
*
score_matrix
,
int
num_items
,
int
num_actions
)
override
;
int
GetInputFeatures
(
const
string
&
component_name
,
std
::
function
<
int32
*
(
int
)
>
allocate_indices
,
...
...
@@ -68,6 +68,12 @@ class ComputeSessionImpl : public ComputeSession {
int
BulkGetInputFeatures
(
const
string
&
component_name
,
const
BulkFeatureExtractor
&
extractor
)
override
;
void
BulkEmbedFixedFeatures
(
const
string
&
component_name
,
int
batch_size_padding
,
int
num_steps_padding
,
int
output_array_size
,
const
vector
<
const
float
*>
&
per_channel_embeddings
,
float
*
embedding_output
)
override
;
std
::
vector
<
LinkFeatures
>
GetTranslatedLinkFeatures
(
const
string
&
component_name
,
int
channel_id
)
override
;
...
...
@@ -84,6 +90,8 @@ class ComputeSessionImpl : public ComputeSession {
void
SetInputData
(
const
std
::
vector
<
string
>
&
data
)
override
;
void
SetInputBatchCache
(
std
::
unique_ptr
<
InputBatchCache
>
batch
)
override
;
void
ResetSession
()
override
;
void
SetTracing
(
bool
tracing_on
)
override
;
...
...
@@ -95,14 +103,14 @@ class ComputeSessionImpl : public ComputeSession {
const
std
::
vector
<
const
IndexTranslator
*>
Translators
(
const
string
&
component_name
)
const
override
;
// Get a given component. CHECK-fail if the component's IsReady method
// returns false.
Component
*
GetReadiedComponent
(
const
string
&
component_name
)
const
override
;
private:
// Get a given component. Fails if the component is not found.
Component
*
GetComponent
(
const
string
&
component_name
)
const
;
// Get a given component. CHECK-fail if the component's IsReady method
// returns false.
Component
*
GetReadiedComponent
(
const
string
&
component_name
)
const
;
// Get the index translators for the given component.
const
std
::
vector
<
IndexTranslator
*>
&
GetTranslators
(
const
string
&
component_name
)
const
;
...
...
@@ -154,4 +162,4 @@ class ComputeSessionImpl : public ComputeSession {
}
// namespace dragnn
}
// namespace syntaxnet
#endif //
NLP_SAFT_OPENSOURCE_
DRAGNN_CORE_COMPUTE_SESSION_IMPL_H_
#endif // DRAGNN_CORE_COMPUTE_SESSION_IMPL_H_
research/syntaxnet/dragnn/core/compute_session_impl_test.cc
View file @
4364390a
...
...
@@ -22,7 +22,9 @@
#include "dragnn/core/component_registry.h"
#include "dragnn/core/compute_session.h"
#include "dragnn/core/compute_session_pool.h"
#include "dragnn/core/input_batch_cache.h"
#include "dragnn/core/interfaces/component.h"
#include "dragnn/core/interfaces/input_batch.h"
#include "dragnn/core/test/generic.h"
#include "dragnn/core/test/mock_component.h"
#include "dragnn/core/test/mock_transition_state.h"
...
...
@@ -65,8 +67,10 @@ class TestComponentType1 : public Component {
int
GetSourceBeamIndex
(
int
current_index
,
int
batch
)
const
override
{
return
0
;
}
void
AdvanceFromPrediction
(
const
float
transition_matrix
[],
int
matrix_length
)
override
{}
bool
AdvanceFromPrediction
(
const
float
*
score_matrix
,
int
num_items
,
int
num_actions
)
override
{
return
true
;
}
void
AdvanceFromOracle
()
override
{}
bool
IsTerminal
()
const
override
{
return
true
;
}
std
::
function
<
int
(
int
,
int
,
int
)
>
GetStepLookupFunction
(
...
...
@@ -83,6 +87,10 @@ class TestComponentType1 : public Component {
int
channel_id
)
const
override
{
return
0
;
}
void
BulkEmbedFixedFeatures
(
int
batch_size_padding
,
int
num_steps_padding
,
int
embedding_size
,
const
vector
<
const
float
*>
&
per_channel_embeddings
,
float
*
embedding_output
)
override
{}
int
BulkGetFixedFeatures
(
const
BulkFeatureExtractor
&
extractor
)
override
{
return
0
;
}
...
...
@@ -133,8 +141,10 @@ class TestComponentType2 : public Component {
int
GetSourceBeamIndex
(
int
current_index
,
int
batch
)
const
override
{
return
0
;
}
void
AdvanceFromPrediction
(
const
float
transition_matrix
[],
int
matrix_length
)
override
{}
bool
AdvanceFromPrediction
(
const
float
*
score_matrix
,
int
num_items
,
int
num_actions
)
override
{
return
true
;
}
void
AdvanceFromOracle
()
override
{}
bool
IsTerminal
()
const
override
{
return
true
;
}
std
::
function
<
int
(
int
,
int
,
int
)
>
GetStepLookupFunction
(
...
...
@@ -151,6 +161,10 @@ class TestComponentType2 : public Component {
int
channel_id
)
const
override
{
return
0
;
}
void
BulkEmbedFixedFeatures
(
int
batch_size_padding
,
int
num_steps_padding
,
int
embedding_size
,
const
vector
<
const
float
*>
&
per_channel_embeddings
,
float
*
embedding_output
)
override
{}
int
BulkGetFixedFeatures
(
const
BulkFeatureExtractor
&
extractor
)
override
{
return
0
;
}
...
...
@@ -201,8 +215,14 @@ class UnreadyComponent : public Component {
int
GetSourceBeamIndex
(
int
current_index
,
int
batch
)
const
override
{
return
0
;
}
void
AdvanceFromPrediction
(
const
float
transition_matrix
[],
int
matrix_length
)
override
{}
bool
AdvanceFromPrediction
(
const
float
*
score_matrix
,
int
num_items
,
int
num_actions
)
override
{
return
true
;
}
void
BulkEmbedFixedFeatures
(
int
batch_size_padding
,
int
num_steps_padding
,
int
embedding_size
,
const
vector
<
const
float
*>
&
per_channel_embeddings
,
float
*
embedding_output
)
override
{}
void
AdvanceFromOracle
()
override
{}
bool
IsTerminal
()
const
override
{
return
false
;
}
std
::
function
<
int
(
int
,
int
,
int
)
>
GetStepLookupFunction
(
...
...
@@ -254,6 +274,18 @@ class ComputeSessionImplTestPoolAccessor {
}
};
// An InputBatch that uses the serialized data directly.
class
IdentityBatch
:
public
InputBatch
{
public:
// Implements InputBatch.
void
SetData
(
const
std
::
vector
<
string
>
&
data
)
override
{
data_
=
data
;
}
int
GetSize
()
const
override
{
return
data_
.
size
();
}
const
std
::
vector
<
string
>
GetSerializedData
()
const
override
{
return
data_
;
}
private:
std
::
vector
<
string
>
data_
;
// the batch data
};
// *****************************************************************************
// Tests begin here.
// *****************************************************************************
...
...
@@ -739,7 +771,7 @@ TEST(ComputeSessionImplTest, InitializesComponentWithSource) {
EXPECT_CALL
(
*
mock_components
[
"component_one"
],
GetBeam
())
.
WillOnce
(
Return
(
beam
));
// Expect that the second component will rec
i
eve that beam.
// Expect that the second component will rece
i
ve that beam.
EXPECT_CALL
(
*
mock_components
[
"component_two"
],
InitializeData
(
beam
,
kMaxBeamSize
,
NotNull
()));
...
...
@@ -899,7 +931,7 @@ TEST(ComputeSessionImplTest, SetTracingPropagatesToAllComponents) {
EXPECT_CALL
(
*
mock_components
[
"component_one"
],
GetBeam
())
.
WillOnce
(
Return
(
beam
));
// Expect that the second component will rec
i
eve that beam, and then its
// Expect that the second component will rece
i
ve that beam, and then its
// tracing will be initialized.
EXPECT_CALL
(
*
mock_components
[
"component_two"
],
InitializeData
(
beam
,
kMaxBeamSize
,
NotNull
()));
...
...
@@ -1084,12 +1116,12 @@ TEST(ComputeSessionImplTest, InterfacePassesThrough) {
session
->
AdvanceFromOracle
(
"component_one"
);
// AdvanceFromPrediction()
const
expr
int
k
ScoreMatrixLength
=
3
;
const
float
score_matrix
[
kScoreMatrixLength
]
=
{
1.0
,
2.3
,
4.5
};
const
int
k
NumActions
=
1
;
const
float
score_matrix
[]
=
{
1.0
,
2.3
,
4.5
};
EXPECT_CALL
(
*
mock_components
[
"component_one"
],
AdvanceFromPrediction
(
score_matrix
,
kScoreMatrixLength
));
session
->
AdvanceFromPrediction
(
"component_one"
,
score_matrix
,
k
ScoreMatrixLength
);
AdvanceFromPrediction
(
score_matrix
,
batch_size
,
kNumActions
));
session
->
AdvanceFromPrediction
(
"component_one"
,
score_matrix
,
batch_size
,
k
NumActions
);
// GetFixedFeatures
auto
allocate_indices
=
[](
int
size
)
->
int32
*
{
return
nullptr
;
};
...
...
@@ -1109,6 +1141,11 @@ TEST(ComputeSessionImplTest, InterfacePassesThrough) {
.
WillOnce
(
Return
(
0
));
EXPECT_EQ
(
0
,
session
->
BulkGetInputFeatures
(
"component_one"
,
extractor
));
// BulkEmbedFixedFeatures
EXPECT_CALL
(
*
mock_components
[
"component_one"
],
BulkEmbedFixedFeatures
(
1
,
2
,
3
,
_
,
_
));
session
->
BulkEmbedFixedFeatures
(
"component_one"
,
1
,
2
,
3
,
{
nullptr
},
nullptr
);
// EmitOracleLabels()
std
::
vector
<
std
::
vector
<
int
>>
oracle_labels
=
{{
0
,
1
},
{
2
,
3
}};
EXPECT_CALL
(
*
mock_components
[
"component_one"
],
GetOracleLabels
())
...
...
@@ -1154,7 +1191,7 @@ TEST(ComputeSessionImplTest, InterfaceRequiresReady) {
constexpr
int
kScoreMatrixLength
=
3
;
const
float
score_matrix
[
kScoreMatrixLength
]
=
{
1.0
,
2.3
,
4.5
};
EXPECT_DEATH
(
session
->
AdvanceFromPrediction
(
"component_one"
,
score_matrix
,
kScoreMatrixLength
),
kScoreMatrixLength
,
1
),
"without first initializing it"
);
constexpr
int
kArbitraryChannelId
=
3
;
EXPECT_DEATH
(
session
->
GetInputFeatures
(
"component_one"
,
nullptr
,
nullptr
,
...
...
@@ -1163,10 +1200,32 @@ TEST(ComputeSessionImplTest, InterfaceRequiresReady) {
BulkFeatureExtractor
extractor
(
nullptr
,
nullptr
,
nullptr
,
false
,
0
,
0
);
EXPECT_DEATH
(
session
->
BulkGetInputFeatures
(
"component_one"
,
extractor
),
"without first initializing it"
);
EXPECT_DEATH
(
session
->
BulkEmbedFixedFeatures
(
"component_one"
,
0
,
0
,
0
,
{
nullptr
},
nullptr
),
"without first initializing it"
);
EXPECT_DEATH
(
session
->
GetTranslatedLinkFeatures
(
"component_one"
,
kArbitraryChannelId
),
"without first initializing it"
);
}
// Verifies that an InputBatchCache injected via SetInputBatchCache() is what
// the session reports back through GetSerializedPredictions().
TEST(ComputeSessionImplTest, SetInputBatchCache) {
  // Empty protos are enough because no component is ever touched.
  MasterSpec spec;
  GridPoint hyperparams;
  ComputeSessionPool pool(spec, hyperparams);
  auto session = pool.GetSession();

  // Build a cache and force its conversion to an IdentityBatch.
  const std::vector<string> expected_data = {"foo", "bar", "baz"};
  std::unique_ptr<InputBatchCache> cache(new InputBatchCache(expected_data));
  cache->GetAs<IdentityBatch>();

  // Hand ownership of the cache over to the session.
  session->SetInputBatchCache(std::move(cache));

  // The session must report exactly the strings held by the injected batch.
  EXPECT_EQ(session->GetSerializedPredictions(), expected_data);
}
}
// namespace dragnn
}
// namespace syntaxnet
research/syntaxnet/dragnn/core/compute_session_pool.h
View file @
4364390a
...
...
@@ -13,8 +13,8 @@
// limitations under the License.
// =============================================================================
#ifndef
NLP_SAFT_OPENSOURCE_
DRAGNN_CORE_COMPUTE_SESSION_POOL_H_
#define
NLP_SAFT_OPENSOURCE_
DRAGNN_CORE_COMPUTE_SESSION_POOL_H_
#ifndef DRAGNN_CORE_COMPUTE_SESSION_POOL_H_
#define DRAGNN_CORE_COMPUTE_SESSION_POOL_H_
#include <memory>
...
...
@@ -29,14 +29,14 @@ namespace dragnn {
class
ComputeSessionPool
{
public:
// Create a ComputeSessionPool that creates ComputeSessions for the given
// Create
s
a ComputeSessionPool that creates ComputeSessions for the given
// MasterSpec and hyperparameters.
ComputeSessionPool
(
const
MasterSpec
&
master_spec
,
const
GridPoint
&
hyperparams
);
virtual
~
ComputeSessionPool
();
// Get a ComputeSession. This function will attempt to use an already-created
// Get
s
a ComputeSession. This function will attempt to use an already-created
// ComputeSession, but if none are available a new one will be created.
std
::
unique_ptr
<
ComputeSession
>
GetSession
();
...
...
@@ -49,6 +49,12 @@ class ComputeSessionPool {
return
num_unique_sessions_
-
sessions_
.
size
();
}
// Returns the number of unique sessions that have been created.
int
num_unique_sessions
()
{
return
num_unique_sessions_
;
}
// Returns a reference to the underlying spec for this pool.
const
MasterSpec
&
GetSpec
()
const
{
return
master_spec_
;
}
private:
friend
class
ComputeSessionImplTestPoolAccessor
;
friend
class
ComputeSessionPoolTestPoolAccessor
;
...
...
@@ -99,4 +105,4 @@ class ComputeSessionPool {
}
// namespace dragnn
}
// namespace syntaxnet
#endif //
NLP_SAFT_OPENSOURCE_
DRAGNN_CORE_COMPUTE_SESSION_POOL_H_
#endif // DRAGNN_CORE_COMPUTE_SESSION_POOL_H_
research/syntaxnet/dragnn/core/compute_session_pool_test.cc
View file @
4364390a
...
...
@@ -207,6 +207,7 @@ TEST(ComputeSessionPoolTest, SupportsMultithreadedAccess) {
std
::
vector
<
std
::
unique_ptr
<
tensorflow
::
Thread
>>
request_threads
;
constexpr
int
kNumThreadsToTest
=
100
;
request_threads
.
reserve
(
kNumThreadsToTest
);
for
(
int
i
=
0
;
i
<
kNumThreadsToTest
;
++
i
)
{
request_threads
.
push_back
(
std
::
unique_ptr
<
tensorflow
::
Thread
>
(
tensorflow
::
Env
::
Default
()
->
StartThread
(
...
...
research/syntaxnet/dragnn/core/index_translator.h
View file @
4364390a
...
...
@@ -13,8 +13,8 @@
// limitations under the License.
// =============================================================================
#ifndef
NLP_SAFT_OPENSOURCE_
DRAGNN_CORE_INDEX_TRANSLATOR_H_
#define
NLP_SAFT_OPENSOURCE_
DRAGNN_CORE_INDEX_TRANSLATOR_H_
#ifndef DRAGNN_CORE_INDEX_TRANSLATOR_H_
#define DRAGNN_CORE_INDEX_TRANSLATOR_H_
#include <memory>
#include <vector>
...
...
@@ -80,4 +80,4 @@ class IndexTranslator {
}
// namespace dragnn
}
// namespace syntaxnet
#endif //
NLP_SAFT_OPENSOURCE_
DRAGNN_CORE_INDEX_TRANSLATOR_H_
#endif // DRAGNN_CORE_INDEX_TRANSLATOR_H_
research/syntaxnet/dragnn/core/input_batch_cache.h
View file @
4364390a
...
...
@@ -13,12 +13,15 @@
// limitations under the License.
// =============================================================================
#ifndef
NLP_SAFT_OPENSOURCE_
DRAGNN_CORE_INPUT_BATCH_CACHE_H_
#define
NLP_SAFT_OPENSOURCE_
DRAGNN_CORE_INPUT_BATCH_CACHE_H_
#ifndef DRAGNN_CORE_INPUT_BATCH_CACHE_H_
#define DRAGNN_CORE_INPUT_BATCH_CACHE_H_
#include <memory>
#include <string>
#include <type_traits>
#include <typeindex>
#include <typeinfo>
#include <utility>
#include "dragnn/core/interfaces/input_batch.h"
#include "tensorflow/core/platform/logging.h"
...
...
@@ -42,6 +45,18 @@ class InputBatchCache {
explicit
InputBatchCache
(
const
std
::
vector
<
string
>
&
data
)
:
stored_type_
(
std
::
type_index
(
typeid
(
void
))),
source_data_
(
data
)
{}
// Creates an InputBatchCache that takes ownership of an already-built |batch|.
// InputBatchSubclass must be a strict subclass of InputBatch (enforced at
// compile time) and |batch| must be non-null (enforced by the CHECK below).
// All calls to GetAs must match InputBatchSubclass.
template <class InputBatchSubclass>
explicit InputBatchCache(std::unique_ptr<InputBatchSubclass> batch)
    : stored_type_(typeid(InputBatchSubclass)),
      converted_data_(std::move(batch)) {
  static_assert(IsStrictInputBatchSubclass<InputBatchSubclass>(),
                "InputBatchCache requires a strict subclass of InputBatch");
  CHECK(converted_data_) << "Cannot initialize from a null InputBatch";
}
// Adds a single string to the cache. Only useable before GetAs() has been
// called.
void
AddData
(
const
string
&
data
)
{
...
...
@@ -52,10 +67,14 @@ class InputBatchCache {
}
// Converts the stored strings into protos and return them in a specific
// InputBatch subclass. T should always be of type InputBatch. After this
// method is called once, all further calls must be of the same data type.
// InputBatch subclass. T should always be a strict subclass of InputBatch.
// After this method is called once, all further calls must be of the same
// data type.
template
<
class
T
>
T
*
GetAs
()
{
static_assert
(
IsStrictInputBatchSubclass
<
T
>
(),
"GetAs<T>() requires that T is a strict subclass of InputBatch"
);
if
(
!
converted_data_
)
{
stored_type_
=
std
::
type_index
(
typeid
(
T
));
converted_data_
.
reset
(
new
T
());
...
...
@@ -69,14 +88,27 @@ class InputBatchCache {
return
dynamic_cast
<
T
*>
(
converted_data_
.
get
());
}
// Returns the number of elements reported by the converted batch. CHECK-fails
// when no converted batch exists, so GetAs() must have been called first (or
// the cache must have been constructed from a batch).
int Size() const {
  CHECK(converted_data_) << "Cannot return batch size without data.";
  const int batch_size = converted_data_->GetSize();
  return batch_size;
}
// Returns the serialized representation of the data held in the input batch
// object within this cache.
// object within this cache.
Requires that GetAs() has been called.
const
std
::
vector
<
string
>
SerializedData
()
const
{
CHECK
(
converted_data_
)
<<
"Cannot return batch without data."
;
return
converted_data_
->
GetSerializedData
();
}
private:
// Returns true iff InputBatchSubclass derives from InputBatch without being
// InputBatch itself. Kept as a single return statement so it remains usable
// inside static_assert.
template <class InputBatchSubclass>
static constexpr bool IsStrictInputBatchSubclass() {
  return !std::is_same<InputBatch, InputBatchSubclass>::value &&
         std::is_base_of<InputBatch, InputBatchSubclass>::value;
}
// The typeid of the stored data.
std
::
type_index
stored_type_
;
...
...
@@ -90,4 +122,4 @@ class InputBatchCache {
}
// namespace dragnn
}
// namespace syntaxnet
#endif //
NLP_SAFT_OPENSOURCE_
DRAGNN_CORE_INPUT_BATCH_CACHE_H_
#endif // DRAGNN_CORE_INPUT_BATCH_CACHE_H_
research/syntaxnet/dragnn/core/input_batch_cache_test.cc
View file @
4364390a
...
...
@@ -32,6 +32,8 @@ class StringData : public InputBatch {
}
}
int
GetSize
()
const
override
{
return
data_
.
size
();
}
const
std
::
vector
<
string
>
GetSerializedData
()
const
override
{
return
data_
;
}
std
::
vector
<
string
>
*
data
()
{
return
&
data_
;
}
...
...
@@ -50,6 +52,8 @@ class DifferentStringData : public InputBatch {
}
}
int
GetSize
()
const
override
{
return
data_
.
size
();
}
const
std
::
vector
<
string
>
GetSerializedData
()
const
override
{
return
data_
;
}
std
::
vector
<
string
>
*
data
()
{
return
&
data_
;
}
...
...
@@ -58,6 +62,11 @@ class DifferentStringData : public InputBatch {
std
::
vector
<
string
>
data_
;
};
// Expects that two pointers have the same address.
void
ExpectSameAddress
(
const
void
*
pointer1
,
const
void
*
pointer2
)
{
EXPECT_EQ
(
pointer1
,
pointer2
);
}
TEST
(
InputBatchCacheTest
,
ConvertsSingleInput
)
{
string
test_string
=
"Foo"
;
InputBatchCache
generic_set
(
test_string
);
...
...
@@ -118,5 +127,48 @@ TEST(InputBatchCacheTest, ConvertsAddedInputDiesAfterGetAs) {
"after the cache has been converted"
);
}
// Checks that, after conversion, the cache exposes both the serialized form
// and the element count of the converted batch.
TEST(InputBatchCacheTest, SerializedDataAndSize) {
  // Feed two raw strings and convert them.
  InputBatchCache cache;
  cache.AddData("Foo");
  cache.AddData("Bar");
  cache.GetAs<StringData>();

  // Both accessors must reflect the converted contents.
  const std::vector<string> expected_data = {"Foo_converted", "Bar_converted"};
  EXPECT_EQ(expected_data, cache.SerializedData());
  EXPECT_EQ(2, cache.Size());
}
// Checks that a cache constructed directly from an InputBatch hands back that
// exact object, reports its contents, and refuses later additions or
// conversions to a different type.
TEST(InputBatchCacheTest, InitializeFromInputBatch) {
  const std::vector<string> kInputData = {"foo", "bar", "baz"};
  const std::vector<string> kExpectedData = {
      "foo_converted", "bar_converted", "baz_converted"};

  // Populate a batch up front and remember its address before giving it away.
  std::unique_ptr<StringData> owned_batch(new StringData());
  owned_batch->SetData(kInputData);
  const StringData *const original_address = owned_batch.get();
  InputBatchCache cache(std::move(owned_batch));

  // The cache must return the very object it was constructed from.
  StringData *const retrieved = cache.GetAs<StringData>();
  ExpectSameAddress(original_address, retrieved);
  EXPECT_EQ(retrieved->GetSize(), 3);
  EXPECT_EQ(retrieved->GetSerializedData(), kExpectedData);
  EXPECT_EQ(*retrieved->data(), kExpectedData);

  // AddData() shouldn't work since the cache is already populated.
  EXPECT_DEATH(cache.AddData("YOU MAY NOT DO THIS AND IT WILL DIE."),
               "after the cache has been converted");

  // GetAs() shouldn't work with a different type.
  EXPECT_DEATH(cache.GetAs<DifferentStringData>(),
               "Attempted to convert to two object types!");
}
TEST
(
InputBatchCacheTest
,
CannotInitializeFromNullInputBatch
)
{
EXPECT_DEATH
(
InputBatchCache
(
std
::
unique_ptr
<
StringData
>
()),
"Cannot initialize from a null InputBatch"
);
}
}
// namespace dragnn
}
// namespace syntaxnet
research/syntaxnet/dragnn/core/interfaces/cloneable_transition_state.h
View file @
4364390a
...
...
@@ -13,8 +13,8 @@
// limitations under the License.
// =============================================================================
#ifndef
NLP_SAFT_OPENSOURCE_
DRAGNN_CORE_INTERFACES_CLONEABLE_TRANSITION_STATE_H_
#define
NLP_SAFT_OPENSOURCE_
DRAGNN_CORE_INTERFACES_CLONEABLE_TRANSITION_STATE_H_
#ifndef DRAGNN_CORE_INTERFACES_CLONEABLE_TRANSITION_STATE_H_
#define DRAGNN_CORE_INTERFACES_CLONEABLE_TRANSITION_STATE_H_
#include <memory>
#include <vector>
...
...
@@ -33,26 +33,32 @@ class CloneableTransitionState : public TransitionState {
public:
~
CloneableTransitionState
<
T
>
()
override
{}
// Initialize this TransitionState from a previous TransitionState. The
// Initialize
s
this TransitionState from a previous TransitionState. The
// ParentBeamIndex is the location of that previous TransitionState in the
// provided beam.
void
Init
(
const
TransitionState
&
parent
)
override
=
0
;
// Return the beam index of the state passed into the initializer of this
// Return
s
the beam index of the state passed into the initializer of this
// TransitionState.
const
int
ParentBeamIndex
()
const
override
=
0
;
int
ParentBeamIndex
()
const
override
=
0
;
// Get the current beam index for this state.
const
int
GetBeamIndex
()
const
override
=
0
;
// Get
s
the current beam index for this state.
int
GetBeamIndex
()
const
override
=
0
;
// Set the current beam index for this state.
void
SetBeamIndex
(
const
int
index
)
override
=
0
;
// Set
s
the current beam index for this state.
void
SetBeamIndex
(
int
index
)
override
=
0
;
// Get the score associated with this transition state.
const
float
GetScore
()
const
override
=
0
;
// Get
s
the score associated with this transition state.
float
GetScore
()
const
override
=
0
;
// Set the score associated with this transition state.
void
SetScore
(
const
float
score
)
override
=
0
;
// Sets the score associated with this transition state.
void
SetScore
(
float
score
)
override
=
0
;
// Gets the gold-ness of this state (whether it is on the oracle path)
bool
IsGold
()
const
override
=
0
;
// Sets the gold-ness of this state.
void
SetGold
(
bool
is_gold
)
override
=
0
;
// Depicts this state as an HTML-language string.
string
HTMLRepresentation
()
const
override
=
0
;
...
...
@@ -64,4 +70,4 @@ class CloneableTransitionState : public TransitionState {
}
// namespace dragnn
}
// namespace syntaxnet
#endif //
NLP_SAFT_OPENSOURCE_
DRAGNN_CORE_INTERFACES_CLONEABLE_TRANSITION_STATE_H_
#endif // DRAGNN_CORE_INTERFACES_CLONEABLE_TRANSITION_STATE_H_
research/syntaxnet/dragnn/core/interfaces/component.h
View file @
4364390a
...
...
@@ -13,8 +13,8 @@
// limitations under the License.
// =============================================================================
#ifndef
NLP_SAFT_OPENSOURCE_
DRAGNN_CORE_INTERFACES_COMPONENT_H_
#define
NLP_SAFT_OPENSOURCE_
DRAGNN_CORE_INTERFACES_COMPONENT_H_
#ifndef DRAGNN_CORE_INTERFACES_COMPONENT_H_
#define DRAGNN_CORE_INTERFACES_COMPONENT_H_
#include <vector>
...
...
@@ -83,11 +83,13 @@ class Component : public RegisterableClass<Component> {
virtual
std
::
function
<
int
(
int
,
int
,
int
)
>
GetStepLookupFunction
(
const
string
&
method
)
=
0
;
// Advances this component from the given transition matrix.
virtual
void
AdvanceFromPrediction
(
const
float
transition_matrix
[],
int
transition_matrix_length
)
=
0
;
// Advances this component from the given transition matrix, which is
// |num_items| x |num_actions|.
virtual
bool
AdvanceFromPrediction
(
const
float
*
score_matrix
,
int
num_items
,
int
num_actions
)
=
0
;
// Advances this component from the state oracles.
// Advances this component from the state oracles. There is no return from
// this, since it should always succeed.
virtual
void
AdvanceFromOracle
()
=
0
;
// Returns true if all states within this component are terminal.
...
...
@@ -110,6 +112,14 @@ class Component : public RegisterableClass<Component> {
// BulkFeatureExtractor object to contain the functors and other information.
virtual
int
BulkGetFixedFeatures
(
const
BulkFeatureExtractor
&
extractor
)
=
0
;
// Directly computes the embedding matrix for all channels, advancing the
// component via the oracle until it is terminal. |per_channel_embeddings|
// holds one pointer to float embedding data per channel, in channel order.
// NOTE(review): |embedding_output| is presumably the caller-allocated buffer
// the result is written to, and |output_array_size| its element count (the
// bulk op kernel passes the output tensor's NumElements()) -- confirm against
// the implementations.
virtual void BulkEmbedFixedFeatures(
    int batch_size_padding, int num_steps_padding, int output_array_size,
    const vector<const float *> &per_channel_embeddings,
    float *embedding_output) = 0;
// Extracts and returns the vector of LinkFeatures for the specified
// channel. Note: these are NOT translated.
virtual
std
::
vector
<
LinkFeatures
>
GetRawLinkFeatures
(
...
...
@@ -138,4 +148,4 @@ class Component : public RegisterableClass<Component> {
}
// namespace dragnn
}
// namespace syntaxnet
#endif //
NLP_SAFT_OPENSOURCE_
DRAGNN_CORE_INTERFACES_COMPONENT_H_
#endif // DRAGNN_CORE_INTERFACES_COMPONENT_H_
research/syntaxnet/dragnn/core/interfaces/input_batch.h
View file @
4364390a
...
...
@@ -13,8 +13,8 @@
// limitations under the License.
// =============================================================================
#ifndef
NLP_SAFT_OPENSOURCE_
DRAGNN_CORE_INTERFACES_INPUT_BATCH_H_
#define
NLP_SAFT_OPENSOURCE_
DRAGNN_CORE_INTERFACES_INPUT_BATCH_H_
#ifndef DRAGNN_CORE_INTERFACES_INPUT_BATCH_H_
#define DRAGNN_CORE_INTERFACES_INPUT_BATCH_H_
#include <string>
#include <vector>
...
...
@@ -32,14 +32,17 @@ class InputBatch {
public:
virtual
~
InputBatch
()
{}
// Set the data to translate to the subclass' data type.
// Set
s
the data to translate to the subclass' data type.
Call at most once.
virtual
void
SetData
(
const
std
::
vector
<
string
>
&
data
)
=
0
;
// Translate the underlying data back to a vector of strings, as appropriate.
// Returns the size of the batch.
virtual
int
GetSize
()
const
=
0
;
// Translates the underlying data back to a vector of strings, as appropriate.
virtual
const
std
::
vector
<
string
>
GetSerializedData
()
const
=
0
;
};
}
// namespace dragnn
}
// namespace syntaxnet
#endif //
NLP_SAFT_OPENSOURCE_
DRAGNN_CORE_INTERFACES_INPUT_BATCH_H_
#endif // DRAGNN_CORE_INTERFACES_INPUT_BATCH_H_
research/syntaxnet/dragnn/core/interfaces/transition_state.h
View file @
4364390a
...
...
@@ -13,8 +13,8 @@
// limitations under the License.
// =============================================================================
#ifndef
NLP_SAFT_OPENSOURCE_
DRAGNN_CORE_INTERFACES_TRANSITION_STATE_H_
#define
NLP_SAFT_OPENSOURCE_
DRAGNN_CORE_INTERFACES_TRANSITION_STATE_H_
#ifndef DRAGNN_CORE_INTERFACES_TRANSITION_STATE_H_
#define DRAGNN_CORE_INTERFACES_TRANSITION_STATE_H_
#include <memory>
#include <vector>
...
...
@@ -44,19 +44,25 @@ class TransitionState {
// Return the beam index of the state passed into the initializer of this
// TransitionState.
virtual
const
int
ParentBeamIndex
()
const
=
0
;
virtual
int
ParentBeamIndex
()
const
=
0
;
// Get the current beam index for this state.
virtual
const
int
GetBeamIndex
()
const
=
0
;
// Get
s
the current beam index for this state.
virtual
int
GetBeamIndex
()
const
=
0
;
// Set the current beam index for this state.
virtual
void
SetBeamIndex
(
const
int
index
)
=
0
;
// Set
s
the current beam index for this state.
virtual
void
SetBeamIndex
(
int
index
)
=
0
;
// Get the score associated with this transition state.
virtual
const
float
GetScore
()
const
=
0
;
// Get
s
the score associated with this transition state.
virtual
float
GetScore
()
const
=
0
;
// Set the score associated with this transition state.
virtual
void
SetScore
(
const
float
score
)
=
0
;
// Sets the score associated with this transition state.
virtual
void
SetScore
(
float
score
)
=
0
;
// Gets the gold-ness of this state (whether it is on the oracle path)
virtual
bool
IsGold
()
const
=
0
;
// Sets the gold-ness of this state.
virtual
void
SetGold
(
bool
is_gold
)
=
0
;
// Depicts this state as an HTML-language string.
virtual
string
HTMLRepresentation
()
const
=
0
;
...
...
@@ -65,4 +71,4 @@ class TransitionState {
}
// namespace dragnn
}
// namespace syntaxnet
#endif //
NLP_SAFT_OPENSOURCE_
DRAGNN_CORE_INTERFACES_TRANSITION_STATE_H_
#endif // DRAGNN_CORE_INTERFACES_TRANSITION_STATE_H_
research/syntaxnet/dragnn/core/ops/compute_session_op.h
View file @
4364390a
...
...
@@ -13,8 +13,8 @@
// limitations under the License.
// =============================================================================
#ifndef
NLP_SAFT_OPENSOURCE_
DRAGNN_CORE_OPS_COMPUTE_SESSION_OP_H_
#define
NLP_SAFT_OPENSOURCE_
DRAGNN_CORE_OPS_COMPUTE_SESSION_OP_H_
#ifndef DRAGNN_CORE_OPS_COMPUTE_SESSION_OP_H_
#define DRAGNN_CORE_OPS_COMPUTE_SESSION_OP_H_
#include <string>
...
...
@@ -66,4 +66,4 @@ class ComputeSessionOp : public tensorflow::OpKernel {
}
// namespace dragnn
}
// namespace syntaxnet
#endif //
NLP_SAFT_OPENSOURCE_
DRAGNN_CORE_OPS_COMPUTE_SESSION_OP_H_
#endif // DRAGNN_CORE_OPS_COMPUTE_SESSION_OP_H_
research/syntaxnet/dragnn/core/ops/dragnn_bulk_op_kernels.cc
View file @
4364390a
...
...
@@ -303,6 +303,73 @@ class BulkFixedEmbeddings : public ComputeSessionOp {
REGISTER_KERNEL_BUILDER
(
Name
(
"BulkFixedEmbeddings"
).
Device
(
DEVICE_CPU
),
BulkFixedEmbeddings
);
// See docstring in dragnn_bulk_ops.cc.
class
BulkEmbedFixedFeatures
:
public
ComputeSessionOp
{
public:
explicit
BulkEmbedFixedFeatures
(
OpKernelConstruction
*
context
)
:
ComputeSessionOp
(
context
)
{
OP_REQUIRES_OK
(
context
,
context
->
GetAttr
(
"num_channels"
,
&
num_channels_
));
// The input vector's zeroth element is the state handle, and the remaining
// num_channels_ elements are tensors of float embeddings, one per channel.
vector
<
DataType
>
input_types
(
num_channels_
+
1
,
DT_FLOAT
);
input_types
[
0
]
=
DT_STRING
;
const
vector
<
DataType
>
output_types
=
{
DT_STRING
,
DT_FLOAT
,
DT_INT32
};
OP_REQUIRES_OK
(
context
,
context
->
MatchSignature
(
input_types
,
output_types
));
OP_REQUIRES_OK
(
context
,
context
->
GetAttr
(
"pad_to_batch"
,
&
pad_to_batch_
));
OP_REQUIRES_OK
(
context
,
context
->
GetAttr
(
"pad_to_steps"
,
&
pad_to_steps_
));
}
bool
OutputsHandle
()
const
override
{
return
true
;
}
bool
RequiresComponentName
()
const
override
{
return
true
;
}
void
ComputeWithState
(
OpKernelContext
*
context
,
ComputeSession
*
session
)
override
{
const
auto
&
spec
=
session
->
Spec
(
component_name
());
int
embedding_size
=
0
;
std
::
vector
<
const
float
*>
embeddings
(
num_channels_
);
for
(
int
channel
=
0
;
channel
<
num_channels_
;
++
channel
)
{
const
int
embeddings_index
=
channel
+
1
;
embedding_size
+=
context
->
input
(
embeddings_index
).
shape
().
dim_size
(
1
)
*
spec
.
fixed_feature
(
channel
).
size
();
embeddings
[
channel
]
=
context
->
input
(
embeddings_index
).
flat
<
float
>
().
data
();
}
Tensor
*
embedding_vectors
;
OP_REQUIRES_OK
(
context
,
context
->
allocate_output
(
1
,
TensorShape
({
pad_to_steps_
*
pad_to_batch_
*
session
->
BeamSize
(
component_name
()),
embedding_size
}),
&
embedding_vectors
));
Tensor
*
num_steps_tensor
;
OP_REQUIRES_OK
(
context
,
context
->
allocate_output
(
2
,
TensorShape
({}),
&
num_steps_tensor
));
embedding_vectors
->
flat
<
float
>
().
setZero
();
int
output_size
=
embedding_vectors
->
NumElements
();
session
->
BulkEmbedFixedFeatures
(
component_name
(),
pad_to_batch_
,
pad_to_steps_
,
output_size
,
embeddings
,
embedding_vectors
->
flat
<
float
>
().
data
());
num_steps_tensor
->
scalar
<
int32
>
()()
=
pad_to_steps_
;
}
private:
// Number of fixed feature channels.
int
num_channels_
;
// Will pad output to this many batch elements.
int
pad_to_batch_
;
// Will pad output to this many steps.
int
pad_to_steps_
;
TF_DISALLOW_COPY_AND_ASSIGN
(
BulkEmbedFixedFeatures
);
};
REGISTER_KERNEL_BUILDER
(
Name
(
"BulkEmbedFixedFeatures"
).
Device
(
DEVICE_CPU
),
BulkEmbedFixedFeatures
);
// See docstring in dragnn_bulk_ops.cc.
class
BulkAdvanceFromOracle
:
public
ComputeSessionOp
{
public:
...
...
@@ -387,8 +454,11 @@ class BulkAdvanceFromPrediction : public ComputeSessionOp {
}
}
if
(
!
session
->
IsTerminal
(
component_name
()))
{
session
->
AdvanceFromPrediction
(
component_name
(),
scores_per_step
.
data
(),
scores_per_step
.
size
());
bool
success
=
session
->
AdvanceFromPrediction
(
component_name
(),
scores_per_step
.
data
(),
num_items
,
num_actions
);
OP_REQUIRES
(
context
,
success
,
tensorflow
::
errors
::
Internal
(
"Unable to advance from prediction."
));
}
}
}
...
...
research/syntaxnet/dragnn/core/ops/dragnn_bulk_op_kernels_test.cc
View file @
4364390a
...
...
@@ -375,6 +375,114 @@ TEST_F(DragnnBulkOpKernelsTest, BulkFixedEmbeddings) {
EXPECT_EQ
(
kNumSteps
,
GetOutput
(
2
)
->
scalar
<
int32
>
()());
}
// Runs the BulkEmbedFixedFeatures kernel against a mock session and checks
// that the padding attributes, output size, and embedding inputs are forwarded
// as expected, and that the output tensor has the expected shape and contents.
TEST_F(DragnnBulkOpKernelsTest, BulkEmbedFixedFeatures) {
  // Create and initialize the kernel under test.
  constexpr int kBatchPad = 7;
  constexpr int kStepPad = 5;
  constexpr int kMaxBeamSize = 3;
  TF_ASSERT_OK(
      NodeDefBuilder("BulkEmbedFixedFeatures", "BulkEmbedFixedFeatures")
          .Attr("component", kComponentName)
          .Attr("num_channels", kNumChannels)
          .Attr("pad_to_batch", kBatchPad)
          .Attr("pad_to_steps", kStepPad)
          .Input(FakeInput(DT_STRING))  // The handle for the ComputeSession.
          .Input(FakeInput(DT_FLOAT))   // Embedding matrices.
          .Finalize(node_def()));
  MockComputeSession *mock_session = GetMockSession();
  ComponentSpec spec;
  spec.set_name(kComponentName);
  auto chan0_spec = spec.add_fixed_feature();
  constexpr int kChan0FeatureCount = 2;
  chan0_spec->set_size(kChan0FeatureCount);
  auto chan1_spec = spec.add_fixed_feature();
  constexpr int kChan1FeatureCount = 1;
  chan1_spec->set_size(kChan1FeatureCount);
  EXPECT_CALL(*mock_session, Spec(kComponentName))
      .WillOnce(testing::ReturnRef(spec));

  EXPECT_CALL(*mock_session, BeamSize(kComponentName))
      .WillOnce(testing::Return(kMaxBeamSize));

  // Embedding matrices as additional inputs.
  // For channel 0, the embeddings are [id, id, id].
  // For channel 1, the embeddings are [id^2, id^2, id^2, ... ,id^2].
  vector<float> embedding_matrix_0;
  constexpr int kEmbedding0Size = 3;
  vector<float> embedding_matrix_1;
  constexpr int kEmbedding1Size = 9;
  for (int id = 0; id < kNumIds; ++id) {
    for (int i = 0; i < kEmbedding0Size; ++i) {
      embedding_matrix_0.push_back(id);
      LOG(INFO) << embedding_matrix_0.back();
    }
    for (int i = 0; i < kEmbedding1Size; ++i) {
      embedding_matrix_1.push_back(id * id);
      // Log the value just appended to channel 1 (not channel 0's last value).
      LOG(INFO) << embedding_matrix_1.back();
    }
  }
  AddInputFromArray<float>(TensorShape({kNumIds, kEmbedding0Size}),
                           embedding_matrix_0);
  AddInputFromArray<float>(TensorShape({kNumIds, kEmbedding1Size}),
                           embedding_matrix_1);

  // Per-row width is the sum over channels of (feature count * embedding dim);
  // the row count is batch padding * step padding * beam size.
  constexpr int kExpectedEmbeddingSize = kChan0FeatureCount * kEmbedding0Size +
                                         kChan1FeatureCount * kEmbedding1Size;
  constexpr int kExpectedOutputSize =
      kExpectedEmbeddingSize * kBatchPad * kStepPad * kMaxBeamSize;

  // Stand-in for ComputeSession::BulkEmbedFixedFeatures: validates the scalar
  // arguments and the embedding data it receives, then writes a known pattern
  // into the output buffer.
  auto eval_function = [=](const string &component_name,
                           int batch_size_padding, int num_steps_padding,
                           int output_array_size,
                           const vector<const float *> &per_channel_embeddings,
                           float *embedding_output) {
    // Validate the control variables.
    EXPECT_EQ(batch_size_padding, kBatchPad);
    EXPECT_EQ(num_steps_padding, kStepPad);
    EXPECT_EQ(output_array_size, kExpectedOutputSize);

    // Validate the passed embeddings.
    for (int i = 0; i < kNumIds; ++i) {
      for (int j = 0; j < kEmbedding0Size; ++j) {
        float ch0_embedding =
            per_channel_embeddings.at(0)[i * kEmbedding0Size + j];
        EXPECT_FLOAT_EQ(ch0_embedding, i)
            << "Failed match at " << i << "," << j;
      }
      for (int j = 0; j < kEmbedding1Size; ++j) {
        float ch1_embedding =
            per_channel_embeddings.at(1)[i * kEmbedding1Size + j];
        EXPECT_FLOAT_EQ(ch1_embedding, i * i)
            << "Failed match at " << i << "," << j;
      }
    }

    // Fill the output matrix to the expected size. Sanitizer builds should
    // flag this write if the allocation wasn't big enough.
    for (int i = 0; i < kExpectedOutputSize; ++i) {
      embedding_output[i] = i;
    }
  };

  EXPECT_CALL(*mock_session,
              BulkEmbedFixedFeatures(kComponentName, _, _, _, _, _))
      .WillOnce(testing::Invoke(eval_function));

  // Run the kernel.
  TF_EXPECT_OK(RunOpKernelWithContext());

  // Validate outputs.
  EXPECT_EQ(kBatchPad * kStepPad * kMaxBeamSize,
            GetOutput(1)->shape().dim_size(0));
  EXPECT_EQ(kExpectedEmbeddingSize, GetOutput(1)->shape().dim_size(1));
  auto output_data = GetOutput(1)->flat<float>();
  for (int i = 0; i < kExpectedOutputSize; ++i) {
    EXPECT_FLOAT_EQ(i, output_data(i));
  }
  EXPECT_EQ(kStepPad, GetOutput(2)->scalar<int32>()());
}
TEST_F
(
DragnnBulkOpKernelsTest
,
BulkFixedEmbeddingsWithPadding
)
{
// Create and initialize the kernel under test.
constexpr
int
kPaddedNumSteps
=
5
;
...
...
@@ -592,12 +700,54 @@ TEST_F(DragnnBulkOpKernelsTest, BulkAdvanceFromPrediction) {
EXPECT_CALL
(
*
mock_session
,
AdvanceFromPrediction
(
kComponentName
,
CheckScoresAreConsecutiveIntegersDivTen
(),
kNumItems
*
kNumActions
))
.
Times
(
kNumSteps
);
kNumItems
,
kNumActions
))
.
Times
(
kNumSteps
)
.
WillRepeatedly
(
Return
(
true
));
// Run the kernel.
TF_EXPECT_OK
(
RunOpKernelWithContext
());
}
// Checks that the BulkAdvanceFromPrediction kernel reports an error when the
// session's AdvanceFromPrediction() returns false (here, on its second call).
TEST_F(DragnnBulkOpKernelsTest, BulkAdvanceFromPredictionFailsIfAdvanceFails) {
  // Create and initialize the kernel under test.
  TF_ASSERT_OK(
      NodeDefBuilder("BulkAdvanceFromPrediction", "BulkAdvanceFromPrediction")
          .Attr("component", kComponentName)
          .Input(FakeInput(DT_STRING))  // The handle for the ComputeSession.
          .Input(FakeInput(DT_FLOAT))   // Prediction scores for advancing.
          .Finalize(node_def()));
  MockComputeSession *mock_session = GetMockSession();

  // Creates an input tensor such that each step will see a list of consecutive
  // integers divided by 10 as scores. Rows are laid out item-major, so the row
  // for (item, step) is item * kNumSteps + step.
  vector<float> scores(kNumItems * kNumSteps * kNumActions);
  int next_value = 0;
  for (int step = 0; step < kNumSteps; ++step) {
    for (int item = 0; item < kNumItems; ++item) {
      const int row = step + item * kNumSteps;
      for (int action = 0; action < kNumActions; ++action) {
        scores[action + kNumActions * row] = next_value / 10.0f;
        ++next_value;
      }
    }
  }
  AddInputFromArray<float>(TensorShape({kNumItems * kNumSteps, kNumActions}),
                           scores);

  EXPECT_CALL(*mock_session, BeamSize(kComponentName)).WillOnce(Return(1));
  EXPECT_CALL(*mock_session, BatchSize(kComponentName))
      .WillOnce(Return(kNumItems));
  EXPECT_CALL(*mock_session, IsTerminal(kComponentName))
      .Times(2)
      .WillRepeatedly(Return(false));

  // The first advance succeeds; the second one fails.
  EXPECT_CALL(*mock_session,
              AdvanceFromPrediction(kComponentName,
                                    CheckScoresAreConsecutiveIntegersDivTen(),
                                    kNumItems, kNumActions))
      .WillOnce(Return(true))
      .WillOnce(Return(false));

  // Run the kernel; the failed advance must surface as a non-OK status.
  EXPECT_FALSE(RunOpKernelWithContext().ok());
}
}
// namespace dragnn
}
// namespace syntaxnet
Prev
1
2
3
4
5
6
…
9
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment