ModelZoo / ResNet50_tensorflow / Commits / 4364390a

Commit 4364390a authored Nov 13, 2017 by Ivan Bogatyy, committed by calberti on Nov 13, 2017

Release DRAGNN bulk networks (#2785)

* Release DRAGNN bulk networks

parent 638fd759

Changes: 166
Showing 20 changed files with 1295 additions and 261 deletions (+1295, -261)
research/syntaxnet/dragnn/core/BUILD  +5  -1
research/syntaxnet/dragnn/core/beam.h  +197  -92
research/syntaxnet/dragnn/core/beam_test.cc  +551  -70
research/syntaxnet/dragnn/core/component_registry.h  +8  -3
research/syntaxnet/dragnn/core/compute_session.h  +27  -7
research/syntaxnet/dragnn/core/compute_session_impl.cc  +21  -5
research/syntaxnet/dragnn/core/compute_session_impl.h  +18  -10
research/syntaxnet/dragnn/core/compute_session_impl_test.cc  +73  -14
research/syntaxnet/dragnn/core/compute_session_pool.h  +11  -5
research/syntaxnet/dragnn/core/compute_session_pool_test.cc  +1  -0
research/syntaxnet/dragnn/core/index_translator.h  +3  -3
research/syntaxnet/dragnn/core/input_batch_cache.h  +38  -6
research/syntaxnet/dragnn/core/input_batch_cache_test.cc  +52  -0
research/syntaxnet/dragnn/core/interfaces/cloneable_transition_state.h  +20  -14
research/syntaxnet/dragnn/core/interfaces/component.h  +17  -7
research/syntaxnet/dragnn/core/interfaces/input_batch.h  +8  -5
research/syntaxnet/dragnn/core/interfaces/transition_state.h  +18  -12
research/syntaxnet/dragnn/core/ops/compute_session_op.h  +3  -3
research/syntaxnet/dragnn/core/ops/dragnn_bulk_op_kernels.cc  +72  -2
research/syntaxnet/dragnn/core/ops/dragnn_bulk_op_kernels_test.cc  +152  -2
research/syntaxnet/dragnn/core/BUILD

...
@@ -33,8 +33,9 @@ cc_library(
    name = "compute_session",
    hdrs = ["compute_session.h"],
    deps = [
        ":index_translator",
        ":input_batch_cache",
        "//dragnn/components/util:bulk_feature_extractor",
        "//dragnn/core:index_translator",
        "//dragnn/core/interfaces:component",
        "//dragnn/protos:spec_proto",
        "//dragnn/protos:trace_proto",
...
@@ -120,8 +121,10 @@ cc_test(
        ":compute_session",
        ":compute_session_impl",
        ":compute_session_pool",
        ":input_batch_cache",
        "//dragnn/components/util:bulk_feature_extractor",
        "//dragnn/core/interfaces:component",
        "//dragnn/core/interfaces:input_batch",
        "//dragnn/core/test:generic",
        "//dragnn/core/test:mock_component",
        "//dragnn/core/test:mock_transition_state",
...
@@ -248,6 +251,7 @@ cc_library(
        "//syntaxnet:base",
        "@org_tensorflow//third_party/eigen3",
    ],
    alwayslink = 1,
)

# Tensorflow kernel libraries, for use with unit tests.
...
research/syntaxnet/dragnn/core/beam.h

...
@@ -13,8 +13,8 @@
// limitations under the License.
// =============================================================================

- #ifndef NLP_SAFT_OPENSOURCE_DRAGNN_CORE_BEAM_H_
- #define NLP_SAFT_OPENSOURCE_DRAGNN_CORE_BEAM_H_
+ #ifndef DRAGNN_CORE_BEAM_H_
+ #define DRAGNN_CORE_BEAM_H_

#include <algorithm>
#include <cmath>
...
@@ -43,19 +43,23 @@ class Beam {
    static_assert(
        std::is_base_of<CloneableTransitionState<T>, T>::value,
        "This class must be instantiated to use a CloneableTransitionState");
    track_gold_ = false;
  }

  // Sets whether or not the beam should track gold states.
  void SetGoldTracking(bool track_gold) { track_gold_ = track_gold; }

  // Sets the Beam functions, as follows:
  // bool is_allowed(TransitionState *, int): Return true if transition 'int' is
  //     allowed for transition state 'TransitionState *'.
  // void perform_transition(TransitionState *, int): Performs transition 'int'
  //     on transition state 'TransitionState *'.
-   // int oracle_function(TransitionState *): Returns the oracle-specified action
-   //     for transition state 'TransitionState *'.
+   // vector<int> oracle_function(TransitionState *): Returns the oracle-
+   //     specified actions for transition state 'TransitionState *'.
  void SetFunctions(std::function<bool(T *, int)> is_allowed,
                    std::function<bool(T *)> is_final,
                    std::function<void(T *, int)> perform_transition,
-                     std::function<int(T *)> oracle_function) {
+                     std::function<vector<int>(T *)> oracle_function) {
    is_allowed_ = is_allowed;
    is_final_ = is_final;
    perform_transition_ = perform_transition;
...
@@ -74,12 +78,17 @@ class Beam {
    for (int i = 0; i < beam_.size(); ++i) {
      previous_beam_indices.at(i) = beam_[i]->ParentBeamIndex();
      beam_[i]->SetBeamIndex(i);
      // TODO(googleuser): Add gold tracking to component-level state creation.
      if (!track_gold_) {
        beam_[i]->SetGold(false);
      }
    }
    beam_index_history_.emplace_back(previous_beam_indices);
  }

  // Advances the Beam from the given transition matrix.
-   void AdvanceFromPrediction(const float transition_matrix[], int matrix_length,
+   bool AdvanceFromPrediction(const float *transition_matrix, int matrix_length,
                             int num_actions) {
    // Ensure that the transition matrix is the correct size. All underlying
    // states should have the same transition profile, so using the one at 0
...
@@ -89,91 +98,20 @@ class Beam {
           "state transitions!";
    if (max_size_ == 1) {
-       // In the case where beam size is 1, we can advance by simply finding the
-       // highest score and advancing the beam state in place.
-       VLOG(2) << "Beam size is 1. Using fast beam path.";
-       int best_action = -1;
-       float best_score = -INFINITY;
-       auto &state = beam_[0];
-       for (int action_idx = 0; action_idx < num_actions; ++action_idx) {
-         if (is_allowed_(state.get(), action_idx) &&
-             transition_matrix[action_idx] > best_score) {
-           best_score = transition_matrix[action_idx];
-           best_action = action_idx;
-         }
-       }
-       CHECK_GE(best_action, 0) << "Num actions: " << num_actions
-                                << " score[0]: " << transition_matrix[0];
-       perform_transition_(state.get(), best_action);
-       const float new_score = state->GetScore() + best_score;
-       state->SetScore(new_score);
-       state->SetBeamIndex(0);
+       bool success = FastAdvanceFromPrediction(transition_matrix, num_actions);
+       if (!success) {
+         return false;
+       }
    } else {
-       // Create the vector of all possible transitions, along with their scores.
-       std::vector<Transition> candidates;
-       // Iterate through all beams, examining all actions for each beam.
-       for (int beam_idx = 0; beam_idx < beam_.size(); ++beam_idx) {
-         const auto &state = beam_[beam_idx];
-         for (int action_idx = 0; action_idx < num_actions; ++action_idx) {
-           // If the action is allowed, calculate the proposed new score and add
-           // the candidate action to the vector of all actions at this state.
-           if (is_allowed_(state.get(), action_idx)) {
-             Transition candidate;
-             // The matrix is laid out by beam index, with a linear set of
-             // actions for that index - so beam N's actions start at [nr. of
-             // actions]*[N].
-             const int matrix_idx = action_idx + beam_idx * num_actions;
-             CHECK_LT(matrix_idx, matrix_length) << "Matrix index out of bounds!";
-             const double score_delta = transition_matrix[matrix_idx];
-             CHECK(!std::isnan(score_delta));
-             candidate.source_idx = beam_idx;
-             candidate.action = action_idx;
-             candidate.resulting_score = state->GetScore() + score_delta;
-             candidates.emplace_back(candidate);
-           }
-         }
-       }
-       // Sort the vector of all possible transitions and scores.
-       const auto comparator = [](const Transition &a, const Transition &b) {
-         return a.resulting_score > b.resulting_score;
-       };
-       std::stable_sort(candidates.begin(), candidates.end(), comparator);
-       // Apply the top transitions, up to a maximum of 'max_size_'.
-       std::vector<std::unique_ptr<T>> new_beam;
-       std::vector<int> previous_beam_indices(max_size_, -1);
-       const int beam_size =
-           std::min(max_size_, static_cast<int>(candidates.size()));
-       VLOG(2) << "Previous beam size = " << beam_.size();
-       VLOG(2) << "New beam size = " << beam_size;
-       VLOG(2) << "Maximum beam size = " << max_size_;
-       for (int i = 0; i < beam_size; ++i) {
-         // Get the source of the i'th transition.
-         const auto &transition = candidates[i];
-         VLOG(2) << "Taking transition with score: " << transition.resulting_score
-                 << " and action: " << transition.action;
-         VLOG(2) << "transition.source_idx = " << transition.source_idx;
-         const auto &source = beam_[transition.source_idx];
-         // Put the new transition on the new state beam.
-         auto new_state = source->Clone();
-         perform_transition_(new_state.get(), transition.action);
-         new_state->SetScore(transition.resulting_score);
-         new_state->SetBeamIndex(i);
-         previous_beam_indices.at(i) = transition.source_idx;
-         new_beam.emplace_back(std::move(new_state));
-       }
-       beam_ = std::move(new_beam);
-       beam_index_history_.emplace_back(previous_beam_indices);
+       bool success = BeamAdvanceFromPrediction(transition_matrix, matrix_length,
+                                                num_actions);
+       if (!success) {
+         return false;
+       }
    }
    ++num_steps_;
+     return true;
  }

  // Advances the Beam from the state oracles.
...
@@ -182,7 +120,10 @@ class Beam {
    for (int i = 0; i < beam_.size(); ++i) {
      previous_beam_indices.at(i) = i;
      if (is_final_(beam_[i].get())) continue;
-       const auto oracle_label = oracle_function_(beam_[i].get());
+       // There will always be at least one oracular transition, and taking the
+       // first returned transition is never worse than any other option.
+       const int oracle_label = oracle_function_(beam_[i].get()).at(0);
      VLOG(2) << "AdvanceFromOracle beam_index:" << i
              << " oracle_label:" << oracle_label;
      perform_transition_(beam_[i].get(), oracle_label);
...
@@ -312,19 +253,180 @@ class Beam {
  // Returns the current size of the beam.
  const int size() const { return beam_.size(); }

  // Returns true if at least one of the states in the beam is gold.
  bool ContainsGold() {
    if (!track_gold_) {
      return false;
    }
    for (const auto &state : beam_) {
      if (state->IsGold()) {
        return true;
      }
    }
    return false;
  }

 private:
-   // Associates an action taken on an index into current_state_ with a score.
  friend void BM_FastAdvance(int num_iters, int num_transitions);
  friend void BM_BeamAdvance(int num_iters, int num_transitions,
                             int max_beam_size);

+   // Associates an action taken with its source index.
  struct Transition {
    // The index of the source item.
    int source_idx;
    // The index of the action being taken.
    int action;
    // The score of the full derivation.
    double resulting_score;
  };

  // In the case where beam size is 1, we can advance by simply finding the
  // highest score and advancing the beam state in place.
  bool FastAdvanceFromPrediction(const float *transition_matrix,
                                 int num_actions) {
    CHECK_EQ(1, max_size_)
        << "Using fast advance on invalid beam. This should never happen.";
    VLOG(2) << "Beam size is 1. Using fast beam path.";
    constexpr int kNoActionChosen = -1;
    int best_action = kNoActionChosen;
    float best_score = -INFINITY;
    auto &state = beam_[0];
    for (int action_idx = 0; action_idx < num_actions; ++action_idx) {
      if (std::isnan(transition_matrix[action_idx])) {
        LOG(ERROR) << "Found a NaN in the transition matrix! Unable to "
                      "continue. Num actions: "
                   << num_actions << " index: " << action_idx;
        return false;
      }
      if (is_allowed_(state.get(), action_idx) &&
          transition_matrix[action_idx] > best_score) {
        best_score = transition_matrix[action_idx];
        best_action = action_idx;
      }
    }
    if (best_action == kNoActionChosen) {
      LOG(ERROR) << "No action was chosen! Unable to continue. Num actions: "
                 << num_actions << " score[0]: " << transition_matrix[0];
      return false;
    }
    bool is_gold = false;
    if (track_gold_ && state->IsGold()) {
      for (const auto &gold_transition : oracle_function_(state.get())) {
        VLOG(3) << "Examining gold transition " << gold_transition
                << " for source index 1";
        if (gold_transition == best_action) {
          is_gold = true;
          break;
        }
      }
    }
    perform_transition_(state.get(), best_action);
    const float new_score = state->GetScore() + best_score;
    state->SetScore(new_score);
    state->SetBeamIndex(0);
    state->SetGold(is_gold);
    return true;
  }

  // In case the beam size is greater than 1, we need to advance using
  // standard beam search.
  bool BeamAdvanceFromPrediction(const float *transition_matrix,
                                 int matrix_length, int num_actions) {
    VLOG(2) << "Beam size is " << max_size_ << ". Using standard beam search.";
    // Keep the multimap sorted high to low. The sort order for
    // identical keys is stable.
    std::multimap<float, Transition, std::greater<float>> candidates;
    float threshold = -INFINITY;
    // Iterate through all beams, examining all actions for each beam.
    for (int beam_idx = 0; beam_idx < beam_.size(); ++beam_idx) {
      const auto &state = beam_[beam_idx];
      const float score = state->GetScore();
      for (int action_idx = 0; action_idx < num_actions; ++action_idx) {
        if (is_allowed_(state.get(), action_idx)) {
          // The matrix is laid out by beam index, with a linear set of
          // actions for that index - so beam N's actions start at [nr. of
          // actions]*[N].
          const int matrix_idx = action_idx + beam_idx * num_actions;
          CHECK_LT(matrix_idx, matrix_length) << "Matrix index out of bounds!";
          const float resulting_score = score + transition_matrix[matrix_idx];
          if (std::isnan(resulting_score)) {
            LOG(ERROR) << "Resulting score was a NaN! Unable to continue. Num "
                          "actions: "
                       << num_actions << " action index " << action_idx;
            return false;
          }
          if (candidates.size() == max_size_) {
            // If the new score is lower than the bottom of the beam, move on.
            if (resulting_score < threshold) {
              continue;
            }
            // Otherwise, remove the bottom of the beam, making space
            // for the new candidate.
            candidates.erase(std::prev(candidates.end()));
          }
          // Add the new candidate, and update the threshold score.
          const Transition candidate{beam_idx, action_idx};
          candidates.emplace(resulting_score, candidate);
          threshold = candidates.rbegin()->first;
        }
      }
    }
    // Apply the top transitions, up to a maximum of 'max_size_'.
    std::vector<std::unique_ptr<T>> new_beam;
    std::vector<int> previous_beam_indices(max_size_, -1);
    const int beam_size = candidates.size();
    new_beam.reserve(max_size_);
    VLOG(2) << "Previous beam size = " << beam_.size();
    VLOG(2) << "New beam size = " << beam_size;
    VLOG(2) << "Maximum beam size = " << max_size_;
    auto candidate_iterator = candidates.cbegin();
    for (int i = 0; i < beam_size; ++i) {
      // Get the score and source of the i'th transition.
      const float resulting_score = candidate_iterator->first;
      const auto &transition = candidate_iterator->second;
      ++candidate_iterator;
      VLOG(2) << "Taking transition with score: " << resulting_score
              << " and action: " << transition.action;
      VLOG(2) << "transition.source_idx = " << transition.source_idx;
      const auto &source = beam_[transition.source_idx];
      // Determine if the transition being taken will result in a gold state.
      bool is_gold = false;
      if (track_gold_ && source->IsGold()) {
        for (const auto &gold_transition : oracle_function_(source.get())) {
          VLOG(3) << "Examining gold transition " << gold_transition
                  << " for source index " << transition.source_idx;
          if (gold_transition == transition.action) {
            VLOG(2) << "State from index " << transition.source_idx
                    << " is gold.";
            is_gold = true;
            break;
          }
        }
      }
      VLOG(2) << "Gold examination complete for source index "
              << transition.source_idx;
      // Put the new transition on the new state beam.
      auto new_state = source->Clone();
      perform_transition_(new_state.get(), transition.action);
      new_state->SetScore(resulting_score);
      new_state->SetBeamIndex(i);
      new_state->SetGold(is_gold);
      previous_beam_indices.at(i) = transition.source_idx;
      new_beam.emplace_back(std::move(new_state));
    }
    beam_ = std::move(new_beam);
    beam_index_history_.emplace_back(previous_beam_indices);
    return true;
  }

  // The maximum beam size.
  int max_size_;
...
@@ -341,7 +443,7 @@ class Beam {
  std::function<void(T *, int)> perform_transition_;

  // Function to provide the oracle action for a given state.
-   std::function<int(T *)> oracle_function_;
+   std::function<vector<int>(T *)> oracle_function_;

  // The history of the states in this beam. The vector indexes across steps.
  // For every step, there is a vector in the vector. This inner vector denotes
...
@@ -355,9 +457,12 @@ class Beam {
  // The number of steps taken so far.
  int num_steps_;

  // Whether to track golden states.
  bool track_gold_;
};

}  // namespace dragnn
}  // namespace syntaxnet

- #endif  // NLP_SAFT_OPENSOURCE_DRAGNN_CORE_BEAM_H_
+ #endif  // DRAGNN_CORE_BEAM_H_
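For orientation, the call pattern implied by the updated header looks roughly like the sketch below. This is not code from the commit: it reuses the TestTransitionState class and the CreateState, null_permissions, null_finality, and transition_function helpers defined in beam_test.cc further down, and it only illustrates the two interface changes visible above (the oracle callback now returns vector<int>, and AdvanceFromPrediction now reports failure through a bool return instead of CHECK-failing).

#include <memory>
#include <utility>
#include <vector>

#include "dragnn/core/beam.h"

// Sketch only: TestTransitionState and the helper lambdas are the test
// fixtures from beam_test.cc, reused here purely for illustration.
void SketchBeamUsage() {
  std::vector<std::unique_ptr<TestTransitionState>> states;
  states.push_back(CreateState(/*score=*/0.0));

  Beam<TestTransitionState> beam(/*max_size=*/1);

  // The oracle now returns every oracle-approved action rather than a single
  // int, so even a trivial oracle wraps its answer in a vector.
  auto oracle = [](TestTransitionState *) -> const std::vector<int> {
    return {0};
  };
  beam.SetFunctions(null_permissions, null_finality, transition_function,
                    oracle);
  beam.SetGoldTracking(true);
  beam.Init(std::move(states));

  // A NaN score or a state with no allowed action now surfaces as a false
  // return value rather than a crash.
  const float scores[] = {1.0f, 2.0f, 3.0f, 4.0f};
  if (!beam.AdvanceFromPrediction(scores, /*matrix_length=*/4,
                                  /*num_actions=*/4)) {
    // Handle the failed advance, e.g. drop or retry this batch.
  }
}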
research/syntaxnet/dragnn/core/beam_test.cc

...
@@ -15,11 +15,15 @@
#include "dragnn/core/beam.h"
#include <limits>
#include <random>
#include "dragnn/core/interfaces/cloneable_transition_state.h"
#include "dragnn/core/interfaces/transition_state.h"
#include "dragnn/core/test/mock_transition_state.h"
#include <gmock/gmock.h>
#include "tensorflow/core/platform/test.h"
#include "tensorflow/core/platform/test_benchmark.h"

namespace syntaxnet {
namespace dragnn {
...
@@ -43,7 +47,7 @@ namespace {
class TestTransitionState
    : public CloneableTransitionState<TestTransitionState> {
 public:
-   TestTransitionState() {}
+   TestTransitionState() : is_gold_(false) {}

  void Init(const TransitionState &parent) override {}
...
@@ -52,19 +56,25 @@ class TestTransitionState
    return ptr;
  }

-   const int ParentBeamIndex() const override { return parent_beam_index_; }
+   int ParentBeamIndex() const override { return parent_beam_index_; }

-   // Get the current beam index for this state.
-   const int GetBeamIndex() const override { return beam_index_; }
+   // Gets the current beam index for this state.
+   int GetBeamIndex() const override { return beam_index_; }

-   // Set the current beam index for this state.
-   void SetBeamIndex(const int index) override { beam_index_ = index; }
+   // Sets the current beam index for this state.
+   void SetBeamIndex(int index) override { beam_index_ = index; }

-   // Get the score associated with this transition state.
-   const float GetScore() const override { return score_; }
+   // Gets the score associated with this transition state.
+   float GetScore() const override { return score_; }

-   // Set the score associated with this transition state.
-   void SetScore(const float score) override { score_ = score; }
+   // Sets the score associated with this transition state.
+   void SetScore(float score) override { score_ = score; }

  // Gets the gold-ness of this state (whether it is on the oracle path)
  bool IsGold() const override { return is_gold_; }

  // Sets the gold-ness of this state.
  void SetGold(bool is_gold) override { is_gold_ = is_gold; }

  // Depicts this state as an HTML-language string.
  string HTMLRepresentation() const override { return ""; }
...
@@ -76,6 +86,8 @@ class TestTransitionState
  float score_;
  int transition_action_;
  bool is_gold_;
};

// This transition function annotates a TestTransitionState with the action that
...
@@ -85,12 +97,14 @@ auto transition_function = [](TestTransitionState *state, int action) {
  cast_state->transition_action_ = action;
};

- // Create oracle and permission functions that do nothing.
- auto null_oracle = [](TestTransitionState *) { return 0; };
+ // Creates oracle and permission functions that do nothing.
+ auto null_oracle = [](TestTransitionState *) -> const vector<int> {
+   return {0};
+ };
auto null_permissions = [](TestTransitionState *, int) { return true; };
auto null_finality = [](TestTransitionState *) { return false; };

- // Create a unique_ptr with a test transition state in it and set its initial
+ // Creates a unique_ptr with a test transition state in it and set its initial
// score.
std::unique_ptr<TestTransitionState> CreateState(float score) {
  std::unique_ptr<TestTransitionState> state;
...
@@ -99,6 +113,16 @@ std::unique_ptr<TestTransitionState> CreateState(float score) {
  return state;
}

// Creates a unique_ptr with a test transition state in it and set its initial
// score. Also, set gold-ness to TRUE.
std::unique_ptr<TestTransitionState> CreateGoldState(float score) {
  std::unique_ptr<TestTransitionState> state;
  state.reset(new TestTransitionState());
  state->SetScore(score);
  state->SetGold(true);
  return state;
}

// Convenience accessor for the action field in TestTransitionState.
int GetTransition(const TransitionState *state) {
  return (dynamic_cast<const TestTransitionState *>(state))->transition_action_;
...
@@ -114,11 +138,51 @@ void SetParentBeamIndex(TransitionState *state, int index) {
// *****************************************************************************
// Tests begin here.
// *****************************************************************************

TEST(BeamTest, AdvancesFromPredictionWithSingleBeamReturnsFalseOnNan) {
  // Create a matrix of transitions.
  constexpr int kNumTransitions = 4;
  constexpr int kMatrixSize = kNumTransitions;
  constexpr float kNan = std::numeric_limits<double>::quiet_NaN();
  constexpr float kTransitionMatrix[kMatrixSize] = {1.0, kNan, 2.0, 3.0};
  constexpr float kOldScore = 3.0;

  // Create the beam and transition it.
  std::vector<std::unique_ptr<TestTransitionState>> states;
  states.push_back(CreateState(kOldScore));
  constexpr int kBeamSize = 1;
  Beam<TestTransitionState> beam(kBeamSize);
  beam.SetFunctions(null_permissions, null_finality, transition_function,
                    null_oracle);
  beam.Init(std::move(states));
  EXPECT_FALSE(beam.AdvanceFromPrediction(kTransitionMatrix, kMatrixSize,
                                          kNumTransitions));
}

TEST(BeamTest, AdvancesFromPredictionWithSingleBeamReturnsFalseOnNoneAllowed) {
  // Create a matrix of transitions.
  constexpr int kNumTransitions = 4;
  constexpr int kMatrixSize = kNumTransitions;
  constexpr float kTransitionMatrix[kMatrixSize] = {30.0, 20.0, 40.0, 10.0};
  constexpr float kOldScore = 3.0;

  // Create the beam and transition it.
  std::vector<std::unique_ptr<TestTransitionState>> states;
  states.push_back(CreateState(kOldScore));
  constexpr int kBeamSize = 1;
  Beam<TestTransitionState> beam(kBeamSize);
  auto empty_permissions = [](TestTransitionState *, int) { return false; };
  beam.SetFunctions(empty_permissions, null_finality, transition_function,
                    null_oracle);
  beam.Init(std::move(states));
  EXPECT_FALSE(beam.AdvanceFromPrediction(kTransitionMatrix, kMatrixSize,
                                          kNumTransitions));
}

TEST(BeamTest, AdvancesFromPredictionWithSingleBeam) {
  // Create a matrix of transitions.
  constexpr int kNumTransitions = 4;
  constexpr int kMatrixSize = kNumTransitions;
-   constexpr float matrix[kMatrixSize] = {30.0, 20.0, 40.0, 10.0};
+   constexpr float kTransitionMatrix[kMatrixSize] = {30.0, 20.0, 40.0, 10.0};
  constexpr int kBestTransition = 2;
  constexpr float kOldScore = 3.0;
...
@@ -130,7 +194,7 @@ TEST(BeamTest, AdvancesFromPredictionWithSingleBeam) {
  beam.SetFunctions(null_permissions, null_finality, transition_function,
                    null_oracle);
  beam.Init(std::move(states));
-   beam.AdvanceFromPrediction(matrix, kMatrixSize, kNumTransitions);
+   beam.AdvanceFromPrediction(kTransitionMatrix, kMatrixSize, kNumTransitions);

  // Validate the new beam.
  EXPECT_EQ(beam.beam().size(), kBeamSize);
...
@@ -139,7 +203,8 @@ TEST(BeamTest, AdvancesFromPredictionWithSingleBeam) {
  EXPECT_EQ(GetTransition(beam.beam().at(0)), kBestTransition);

  // Make sure the state has had its score updated properly.
-   EXPECT_EQ(beam.beam().at(0)->GetScore(), kOldScore + matrix[kBestTransition]);
+   EXPECT_EQ(beam.beam().at(0)->GetScore(),
+             kOldScore + kTransitionMatrix[kBestTransition]);

  // Make sure that the beam index field is consistent with the actual beam idx.
  EXPECT_EQ(beam.beam().at(0)->GetBeamIndex(), 0);
...
@@ -152,12 +217,166 @@ TEST(BeamTest, AdvancesFromPredictionWithSingleBeam) {
  EXPECT_EQ(history.at(1).at(0), 0);
}

TEST(BeamTest, NewlyCreatedStatesWithTrackingOffAreNotGold) {
  // Create the beam.
  std::vector<std::unique_ptr<TestTransitionState>> states;
  constexpr float kOldScore = 3.0;
  states.push_back(CreateGoldState(kOldScore));
  constexpr int kBeamSize = 1;
  Beam<TestTransitionState> beam(kBeamSize);
  beam.SetFunctions(null_permissions, null_finality, transition_function,
                    null_oracle);
  // SetGoldTracking is false by default.
  beam.SetGoldTracking(false);
  beam.Init(std::move(states));

  // Validate that the beam still has a gold state in it.
  EXPECT_FALSE(beam.ContainsGold());
}

TEST(BeamTest, AdvancesFromPredictionWithSingleBeamAndGoldTracking) {
  // Create a matrix of transitions.
  constexpr int kNumTransitions = 4;
  constexpr int kMatrixSize = kNumTransitions;
  constexpr float kTransitionMatrix[kMatrixSize] = {30.0, 20.0, 40.0, 10.0};
  constexpr int kBestTransition = 2;
  constexpr float kOldScore = 3.0;

  // Create the beam and transition it.
  std::vector<std::unique_ptr<TestTransitionState>> states;
  states.push_back(CreateGoldState(kOldScore));
  constexpr int kBeamSize = 1;
  Beam<TestTransitionState> beam(kBeamSize);

  // Create an oracle that indicates the best transition is index 2.
  testing::MockFunction<const vector<int>(TestTransitionState *)>
      mock_oracle_function;
  vector<int> oracle_labels = {1, 2};
  EXPECT_CALL(mock_oracle_function, Call(_)).WillOnce(Return(oracle_labels));
  beam.SetFunctions(null_permissions, null_finality, transition_function,
                    mock_oracle_function.AsStdFunction());
  beam.SetGoldTracking(true);
  beam.Init(std::move(states));
  beam.AdvanceFromPrediction(kTransitionMatrix, kMatrixSize, kNumTransitions);

  // Validate the new beam.
  EXPECT_EQ(beam.beam().size(), kBeamSize);

  // Make sure the state has performed the expected transition.
  EXPECT_EQ(GetTransition(beam.beam()[0]), kBestTransition);

  // Make sure the state has had its score updated properly.
  EXPECT_EQ(beam.beam()[0]->GetScore(),
            kOldScore + kTransitionMatrix[kBestTransition]);

  // Make sure that the beam index field is consistent with the actual beam idx.
  EXPECT_EQ(beam.beam()[0]->GetBeamIndex(), 0);

  // Make sure that the beam_state accessor actually accesses the beam.
  EXPECT_EQ(beam.beam()[0], beam.beam_state(0));

  // Validate the beam history field.
  auto history = beam.history();
  EXPECT_EQ(history[1][0], 0);

  // Validate that the beam still has a gold state in it.
  EXPECT_TRUE(beam.ContainsGold());
}

TEST(BeamTest, AdvancesFromPredictionWithSingleBeamAndGoldTrackingFalloff) {
  // Create a matrix of transitions.
  constexpr int kNumTransitions = 4;
  constexpr int kMatrixSize = kNumTransitions;
  constexpr float kTransitionMatrix[kMatrixSize] = {30.0, 20.0, 40.0, 10.0};
  constexpr int kBestTransition = 2;
  constexpr float kOldScore = 3.0;

  // Create the beam and transition it.
  std::vector<std::unique_ptr<TestTransitionState>> states;
  states.push_back(CreateGoldState(kOldScore));
  constexpr int kBeamSize = 1;
  Beam<TestTransitionState> beam(kBeamSize);

  // Create an oracle that indicates the best transition is NOT index 2.
  testing::MockFunction<const vector<int>(TestTransitionState *)>
      mock_oracle_function;
  vector<int> oracle_labels = {0, 1};
  EXPECT_CALL(mock_oracle_function, Call(_)).WillOnce(Return(oracle_labels));
  beam.SetFunctions(null_permissions, null_finality, transition_function,
                    mock_oracle_function.AsStdFunction());
  beam.SetGoldTracking(true);
  beam.Init(std::move(states));
  beam.AdvanceFromPrediction(kTransitionMatrix, kMatrixSize, kNumTransitions);

  // Validate the new beam.
  EXPECT_EQ(beam.beam().size(), kBeamSize);

  // Make sure the state has performed the expected transition.
  EXPECT_EQ(GetTransition(beam.beam()[0]), kBestTransition);

  // Make sure the state has had its score updated properly.
  EXPECT_EQ(beam.beam()[0]->GetScore(),
            kOldScore + kTransitionMatrix[kBestTransition]);

  // Make sure that the beam index field is consistent with the actual beam idx.
  EXPECT_EQ(beam.beam()[0]->GetBeamIndex(), 0);

  // Make sure that the beam_state accessor actually accesses the beam.
  EXPECT_EQ(beam.beam()[0], beam.beam_state(0));

  // Validate the beam history field.
  auto history = beam.history();
  EXPECT_EQ(history[1][0], 0);

  // Validate that the beam has no gold state in it.
  EXPECT_FALSE(beam.ContainsGold());
}

TEST(BeamTest, NonGoldBeamDoesNotInvokeOracle) {
  // Create a matrix of transitions.
  constexpr int kNumTransitions = 4;
  constexpr int kMatrixSize = kNumTransitions;
  constexpr float kTransitionMatrix[kMatrixSize] = {30.0, 20.0, 40.0, 10.0};
  constexpr float kOldScore = 3.0;

  // Create the beam and transition it.
  std::vector<std::unique_ptr<TestTransitionState>> states;
  states.push_back(CreateGoldState(kOldScore));
  auto first_state = states[0].get();
  constexpr int kBeamSize = 1;
  Beam<TestTransitionState> beam(kBeamSize);

  // Create an oracle that indicates the best transition is NOT index 2.
  testing::MockFunction<const vector<int>(TestTransitionState *)>
      mock_oracle_function;
  vector<int> oracle_labels = {0, 1};
  EXPECT_CALL(mock_oracle_function, Call(first_state))
      .WillOnce(Return(oracle_labels));
  beam.SetFunctions(null_permissions, null_finality, transition_function,
                    mock_oracle_function.AsStdFunction());
  beam.SetGoldTracking(true);
  beam.Init(std::move(states));
  beam.AdvanceFromPrediction(kTransitionMatrix, kMatrixSize, kNumTransitions);

  // Validate that the beam has no gold state in it.
  EXPECT_FALSE(beam.ContainsGold());

  // Advance again. Since the oracle function above expects to be called exactly
  // once, another call should not match and cause a failure.
  beam.AdvanceFromPrediction(kTransitionMatrix, kMatrixSize, kNumTransitions);
}

TEST(BeamTest, AdvancingCreatesNewTransitions) {
  // Create a matrix of transitions.
  constexpr int kMaxBeamSize = 8;
  constexpr int kNumTransitions = 4;
  constexpr int kMatrixSize = kNumTransitions * kMaxBeamSize;
-   constexpr float matrix[kMatrixSize] = {
+   constexpr float kTransitionMatrix[kMatrixSize] = {
      30.0, 20.0, 40.0, 10.0, 00.0, 00.0, 00.0, 00.0,
      00.0, 00.0, 00.0, 00.0, 00.0, 00.0, 00.0, 00.0,
      00.0, 00.0, 00.0, 00.0, 00.0, 00.0, 00.0, 00.0,
      00.0, 00.0, 00.0, 00.0, 00.0, 00.0, 00.0, 00.0};
...
@@ -171,7 +390,7 @@ TEST(BeamTest, AdvancingCreatesNewTransitions) {
  beam.SetFunctions(null_permissions, null_finality, transition_function,
                    null_oracle);
  beam.Init(std::move(states));
-   beam.AdvanceFromPrediction(matrix, kMatrixSize, kNumTransitions);
+   beam.AdvanceFromPrediction(kTransitionMatrix, kMatrixSize, kNumTransitions);

  // Validate the new beam.
  EXPECT_EQ(beam.beam().size(), 4);
...
@@ -183,10 +402,10 @@ TEST(BeamTest, AdvancingCreatesNewTransitions) {
  EXPECT_EQ(GetTransition(beam.beam().at(3)), 3);

  // Make sure the state has had its score updated properly.
-   EXPECT_EQ(beam.beam().at(0)->GetScore(), kOldScore + matrix[2]);
-   EXPECT_EQ(beam.beam().at(1)->GetScore(), kOldScore + matrix[0]);
-   EXPECT_EQ(beam.beam().at(2)->GetScore(), kOldScore + matrix[1]);
-   EXPECT_EQ(beam.beam().at(3)->GetScore(), kOldScore + matrix[3]);
+   EXPECT_EQ(beam.beam().at(0)->GetScore(), kOldScore + kTransitionMatrix[2]);
+   EXPECT_EQ(beam.beam().at(1)->GetScore(), kOldScore + kTransitionMatrix[0]);
+   EXPECT_EQ(beam.beam().at(2)->GetScore(), kOldScore + kTransitionMatrix[1]);
+   EXPECT_EQ(beam.beam().at(3)->GetScore(), kOldScore + kTransitionMatrix[3]);

  // Make sure that the beam index field is consistent with the actual beam idx.
  for (int i = 0; i < beam.beam().size(); ++i) {
...
@@ -212,7 +431,7 @@ TEST(BeamTest, MultipleElementBeamsAdvanceAllElements) {
  constexpr int kNumTransitions = 4;
  constexpr int kMatrixSize = kNumTransitions * kMaxBeamSize;
-   constexpr float matrix[kMatrixSize] = {
+   constexpr float kTransitionMatrix[kMatrixSize] = {
      30.0, 20.0, 40.0, 10.0,  // State 0
      31.0, 21.0, 41.0, 11.0,
      00.0, 00.0, 00.0, 00.0, 00.0, 00.0, 00.0, 00.0,
      00.0, 00.0, 00.0, 00.0, 00.0, 00.0, 00.0, 00.0,
...
@@ -229,7 +448,7 @@ TEST(BeamTest, MultipleElementBeamsAdvanceAllElements) {
  beam.SetFunctions(null_permissions, null_finality, transition_function,
                    null_oracle);
  beam.Init(std::move(states));
-   beam.AdvanceFromPrediction(matrix, kMatrixSize, kNumTransitions);
+   beam.AdvanceFromPrediction(kTransitionMatrix, kMatrixSize, kNumTransitions);

  // Validate the new beam.
  EXPECT_EQ(beam.beam().size(), 8);
...
@@ -247,14 +466,22 @@ TEST(BeamTest, MultipleElementBeamsAdvanceAllElements) {
  EXPECT_EQ(GetTransition(beam.beam().at(7)), 3);

  // Make sure the state has had its score updated properly.
-   EXPECT_EQ(beam.beam().at(0)->GetScore(), kOldScores[1] + matrix[6]);
-   EXPECT_EQ(beam.beam().at(1)->GetScore(), kOldScores[0] + matrix[2]);
-   EXPECT_EQ(beam.beam().at(2)->GetScore(), kOldScores[1] + matrix[4]);
-   EXPECT_EQ(beam.beam().at(3)->GetScore(), kOldScores[0] + matrix[0]);
-   EXPECT_EQ(beam.beam().at(4)->GetScore(), kOldScores[1] + matrix[5]);
-   EXPECT_EQ(beam.beam().at(5)->GetScore(), kOldScores[0] + matrix[1]);
-   EXPECT_EQ(beam.beam().at(6)->GetScore(), kOldScores[1] + matrix[7]);
-   EXPECT_EQ(beam.beam().at(7)->GetScore(), kOldScores[0] + matrix[3]);
+   EXPECT_EQ(beam.beam().at(0)->GetScore(), kOldScores[1] + kTransitionMatrix[6]);
+   EXPECT_EQ(beam.beam().at(1)->GetScore(), kOldScores[0] + kTransitionMatrix[2]);
+   EXPECT_EQ(beam.beam().at(2)->GetScore(), kOldScores[1] + kTransitionMatrix[4]);
+   EXPECT_EQ(beam.beam().at(3)->GetScore(), kOldScores[0] + kTransitionMatrix[0]);
+   EXPECT_EQ(beam.beam().at(4)->GetScore(), kOldScores[1] + kTransitionMatrix[5]);
+   EXPECT_EQ(beam.beam().at(5)->GetScore(), kOldScores[0] + kTransitionMatrix[1]);
+   EXPECT_EQ(beam.beam().at(6)->GetScore(), kOldScores[1] + kTransitionMatrix[7]);
+   EXPECT_EQ(beam.beam().at(7)->GetScore(), kOldScores[0] + kTransitionMatrix[3]);

  // Make sure that the beam index field is consistent with the actual beam idx.
  for (int i = 0; i < beam.beam().size(); ++i) {
...
@@ -273,19 +500,255 @@ TEST(BeamTest, MultipleElementBeamsAdvanceAllElements) {
  EXPECT_EQ(history.at(1).at(7), 0);
}

TEST
(
BeamTest
,
MultipleElementBeamsFailOnNan
)
{
// Create a matrix of transitions.
constexpr
int
kMaxBeamSize
=
8
;
constexpr
int
kNumTransitions
=
4
;
constexpr
int
kMatrixSize
=
kNumTransitions
*
kMaxBeamSize
;
constexpr
float
kNan
=
std
::
numeric_limits
<
double
>::
quiet_NaN
();
constexpr
float
kTransitionMatrix
[
kMatrixSize
]
=
{
30.0
,
20.0
,
40.0
,
10.0
,
// State 0
31.0
,
21.0
,
kNan
,
11.0
,
00.0
,
00.0
,
00.0
,
00.0
,
00.0
,
00.0
,
00.0
,
00.0
,
00.0
,
00.0
,
00.0
,
00.0
,
00.0
,
00.0
,
00.0
,
00.0
,
00.0
,
00.0
,
00.0
,
00.0
,
00.0
,
00.0
,
00.0
,
00.0
};
constexpr
float
kOldScores
[]
=
{
5.0
,
7.0
};
// Create the beam and transition it.
std
::
vector
<
std
::
unique_ptr
<
TestTransitionState
>>
states
;
states
.
push_back
(
CreateState
(
kOldScores
[
0
]));
states
.
push_back
(
CreateState
(
kOldScores
[
1
]));
Beam
<
TestTransitionState
>
beam
(
kMaxBeamSize
);
beam
.
SetFunctions
(
null_permissions
,
null_finality
,
transition_function
,
null_oracle
);
beam
.
Init
(
std
::
move
(
states
));
EXPECT_FALSE
(
beam
.
AdvanceFromPrediction
(
kTransitionMatrix
,
kMatrixSize
,
kNumTransitions
));
}
TEST
(
BeamTest
,
AdvancesFromPredictionWithMultipleStateBeamAndGoldTracking
)
{
// Create a matrix of transitions.
constexpr
int
kNumTransitions
=
4
;
constexpr
int
kMaxBeamSize
=
8
;
constexpr
int
kMatrixSize
=
kNumTransitions
*
kMaxBeamSize
;
constexpr
float
kTransitionMatrix
[
kMatrixSize
]
=
{
30.0
,
20.0
,
40.0
,
10.0
,
// State 0
31.0
,
21.0
,
41.0
,
11.0
,
// State 1
32.0
,
22.0
,
42.0
,
12.0
,
// State 2
33.0
,
23.0
,
43.0
,
13.0
,
// State 3
34.0
,
24.0
,
44.0
,
14.0
,
// State 4
35.0
,
25.0
,
45.0
,
15.0
,
// State 5
36.0
,
26.0
,
46.0
,
16.0
,
// State 6
37.0
,
27.0
,
47.0
,
17.0
};
// State 7
constexpr
float
kOldScores
[]
=
{
0.1
,
0.2
,
0.3
,
0.4
,
0.5
,
0.6
,
0.7
,
0.8
};
// Create the beam and transition it.
std
::
vector
<
std
::
unique_ptr
<
TestTransitionState
>>
states
;
states
.
push_back
(
CreateGoldState
(
kOldScores
[
0
]));
states
.
push_back
(
CreateGoldState
(
kOldScores
[
1
]));
states
.
push_back
(
CreateGoldState
(
kOldScores
[
2
]));
states
.
push_back
(
CreateGoldState
(
kOldScores
[
3
]));
states
.
push_back
(
CreateGoldState
(
kOldScores
[
4
]));
states
.
push_back
(
CreateGoldState
(
kOldScores
[
5
]));
states
.
push_back
(
CreateGoldState
(
kOldScores
[
6
]));
states
.
push_back
(
CreateGoldState
(
kOldScores
[
7
]));
// Arbitrarily choose state 4 as the golden state.
auto
gold_state
=
states
[
4
].
get
();
// Create an oracle that will only return one gold transition - on transition
// 2 for state 6 (arbitrarily).
testing
::
MockFunction
<
const
vector
<
int
>
(
TestTransitionState
*
)
>
mock_oracle_function
;
vector
<
int
>
oracle_labels
=
{
0
,
2
};
vector
<
int
>
null_labels
=
{};
EXPECT_CALL
(
mock_oracle_function
,
Call
(
testing
::
Ne
(
gold_state
)))
.
WillRepeatedly
(
Return
(
null_labels
));
EXPECT_CALL
(
mock_oracle_function
,
Call
(
gold_state
))
.
WillOnce
(
Return
(
oracle_labels
));
Beam
<
TestTransitionState
>
beam
(
kMaxBeamSize
);
beam
.
SetFunctions
(
null_permissions
,
null_finality
,
transition_function
,
mock_oracle_function
.
AsStdFunction
());
beam
.
SetGoldTracking
(
true
);
beam
.
Init
(
std
::
move
(
states
));
beam
.
AdvanceFromPrediction
(
kTransitionMatrix
,
kMatrixSize
,
kNumTransitions
);
// Validate the new beam.
EXPECT_EQ
(
beam
.
beam
().
size
(),
8
);
// Make sure the state has performed the expected transition.
// In this case, every state will perform transition 2.
EXPECT_EQ
(
GetTransition
(
beam
.
beam
()[
0
]),
2
);
EXPECT_EQ
(
GetTransition
(
beam
.
beam
()[
1
]),
2
);
EXPECT_EQ
(
GetTransition
(
beam
.
beam
()[
2
]),
2
);
EXPECT_EQ
(
GetTransition
(
beam
.
beam
()[
3
]),
2
);
EXPECT_EQ
(
GetTransition
(
beam
.
beam
()[
4
]),
2
);
EXPECT_EQ
(
GetTransition
(
beam
.
beam
()[
5
]),
2
);
EXPECT_EQ
(
GetTransition
(
beam
.
beam
()[
6
]),
2
);
EXPECT_EQ
(
GetTransition
(
beam
.
beam
()[
7
]),
2
);
// Make sure the state has had its score updated properly. (Note that row
// 0 had the smallest transition score, so it ends up on the bottom of the
// beam, and so forth.) For the matrix index, N*kNumTransitions gets into the
// correct state row and we add 2 since that was the transition index.
EXPECT_EQ
(
beam
.
beam
()[
0
]
->
GetScore
(),
kOldScores
[
7
]
+
kTransitionMatrix
[
7
*
kNumTransitions
+
2
]);
EXPECT_FALSE
(
beam
.
beam
()[
0
]
->
IsGold
());
EXPECT_EQ
(
beam
.
beam
()[
1
]
->
GetScore
(),
kOldScores
[
6
]
+
kTransitionMatrix
[
6
*
kNumTransitions
+
2
]);
EXPECT_FALSE
(
beam
.
beam
()[
1
]
->
IsGold
());
EXPECT_EQ
(
beam
.
beam
()[
2
]
->
GetScore
(),
kOldScores
[
5
]
+
kTransitionMatrix
[
5
*
kNumTransitions
+
2
]);
EXPECT_FALSE
(
beam
.
beam
()[
2
]
->
IsGold
());
// This should be the gold state.
EXPECT_EQ
(
beam
.
beam
()[
3
]
->
GetScore
(),
kOldScores
[
4
]
+
kTransitionMatrix
[
4
*
kNumTransitions
+
2
]);
EXPECT_TRUE
(
beam
.
beam
()[
3
]
->
IsGold
());
EXPECT_EQ
(
beam
.
beam
()[
4
]
->
GetScore
(),
kOldScores
[
3
]
+
kTransitionMatrix
[
3
*
kNumTransitions
+
2
]);
EXPECT_FALSE
(
beam
.
beam
()[
4
]
->
IsGold
());
EXPECT_EQ
(
beam
.
beam
()[
5
]
->
GetScore
(),
kOldScores
[
2
]
+
kTransitionMatrix
[
2
*
kNumTransitions
+
2
]);
EXPECT_FALSE
(
beam
.
beam
()[
5
]
->
IsGold
());
EXPECT_EQ
(
beam
.
beam
()[
6
]
->
GetScore
(),
kOldScores
[
1
]
+
kTransitionMatrix
[
1
*
kNumTransitions
+
2
]);
EXPECT_FALSE
(
beam
.
beam
()[
6
]
->
IsGold
());
EXPECT_EQ
(
beam
.
beam
()[
7
]
->
GetScore
(),
kOldScores
[
0
]
+
kTransitionMatrix
[
0
*
kNumTransitions
+
2
]);
EXPECT_FALSE
(
beam
.
beam
()[
7
]
->
IsGold
());
// Validate that the beam still has a gold state in it.
EXPECT_TRUE
(
beam
.
ContainsGold
());
}
TEST
(
BeamTest
,
AdvancesFromPredictionWithMultipleGoldStates
)
{
// Create a matrix of transitions.
constexpr
int
kNumTransitions
=
4
;
constexpr
int
kMaxBeamSize
=
8
;
constexpr
int
kMatrixSize
=
kNumTransitions
*
kMaxBeamSize
;
constexpr
float
kTransitionMatrix
[
kMatrixSize
]
=
{
30.0
,
20.0
,
40.0
,
10.0
,
// State 0
31.0
,
21.0
,
41.0
,
11.0
,
// State 1
32.0
,
22.0
,
42.0
,
12.0
,
// State 2
33.0
,
23.0
,
43.0
,
13.0
,
// State 3
54.0
,
24.0
,
44.0
,
14.0
,
// State 4 (gold - next will have both states)
35.0
,
25.0
,
45.0
,
15.0
,
// State 5
36.0
,
26.0
,
46.0
,
16.0
,
// State 6
37.0
,
27.0
,
47.0
,
17.0
};
// State 7
constexpr
float
kOldScores
[]
=
{
0.1
,
0.2
,
0.3
,
0.4
,
0.5
,
0.6
,
0.7
,
0.8
};
// Create the beam and transition it.
std
::
vector
<
std
::
unique_ptr
<
TestTransitionState
>>
states
;
states
.
push_back
(
CreateState
(
kOldScores
[
0
]));
states
.
push_back
(
CreateState
(
kOldScores
[
1
]));
states
.
push_back
(
CreateState
(
kOldScores
[
2
]));
states
.
push_back
(
CreateState
(
kOldScores
[
3
]));
states
.
push_back
(
CreateGoldState
(
kOldScores
[
4
]));
states
.
push_back
(
CreateState
(
kOldScores
[
5
]));
states
.
push_back
(
CreateState
(
kOldScores
[
6
]));
states
.
push_back
(
CreateState
(
kOldScores
[
7
]));
// Arbitrarily choose state 4 as the golden state.
auto
gold_state
=
states
[
4
].
get
();
// Create an oracle that will only return one gold transition - on transition
// 2 for state 6 (arbitrarily).
testing
::
MockFunction
<
const
vector
<
int
>
(
TestTransitionState
*
)
>
mock_oracle_function
;
vector
<
int
>
oracle_labels
=
{
0
,
2
};
vector
<
int
>
null_labels
=
{};
EXPECT_CALL
(
mock_oracle_function
,
Call
(
gold_state
))
.
WillOnce
(
Return
(
oracle_labels
))
.
WillOnce
(
Return
(
oracle_labels
));
Beam
<
TestTransitionState
>
beam
(
kMaxBeamSize
);
beam
.
SetFunctions
(
null_permissions
,
null_finality
,
transition_function
,
mock_oracle_function
.
AsStdFunction
());
beam
.
SetGoldTracking
(
true
);
beam
.
Init
(
std
::
move
(
states
));
beam
.
AdvanceFromPrediction
(
kTransitionMatrix
,
kMatrixSize
,
kNumTransitions
);
// Validate the new beam.
EXPECT_EQ
(
beam
.
beam
().
size
(),
8
);
// Make sure the state has performed the expected transition.
// In this case, every state will perform transition 2.
EXPECT_EQ
(
GetTransition
(
beam
.
beam
()[
0
]),
0
);
EXPECT_EQ
(
GetTransition
(
beam
.
beam
()[
1
]),
2
);
EXPECT_EQ
(
GetTransition
(
beam
.
beam
()[
2
]),
2
);
EXPECT_EQ
(
GetTransition
(
beam
.
beam
()[
3
]),
2
);
EXPECT_EQ
(
GetTransition
(
beam
.
beam
()[
4
]),
2
);
EXPECT_EQ
(
GetTransition
(
beam
.
beam
()[
5
]),
2
);
EXPECT_EQ
(
GetTransition
(
beam
.
beam
()[
6
]),
2
);
EXPECT_EQ
(
GetTransition
(
beam
.
beam
()[
7
]),
2
);
// Make sure the state has had its score updated properly. (Note that row
// 0 had the smallest transition score, so it ends up on the bottom of the
// beam, and so forth.) For the matrix index, N*kNumTransitions gets into the
// correct state row and we add 2 since that was the transition index.
// This should be a gold state.
EXPECT_EQ
(
beam
.
beam
()[
0
]
->
GetScore
(),
kOldScores
[
4
]
+
kTransitionMatrix
[
4
*
kNumTransitions
+
0
]);
EXPECT_TRUE
(
beam
.
beam
()[
0
]
->
IsGold
());
EXPECT_EQ
(
beam
.
beam
()[
1
]
->
GetScore
(),
kOldScores
[
7
]
+
kTransitionMatrix
[
7
*
kNumTransitions
+
2
]);
EXPECT_FALSE
(
beam
.
beam
()[
1
]
->
IsGold
());
EXPECT_EQ
(
beam
.
beam
()[
2
]
->
GetScore
(),
kOldScores
[
6
]
+
kTransitionMatrix
[
6
*
kNumTransitions
+
2
]);
EXPECT_FALSE
(
beam
.
beam
()[
2
]
->
IsGold
());
EXPECT_EQ
(
beam
.
beam
()[
3
]
->
GetScore
(),
kOldScores
[
5
]
+
kTransitionMatrix
[
5
*
kNumTransitions
+
2
]);
EXPECT_FALSE
(
beam
.
beam
()[
3
]
->
IsGold
());
// This should be a gold state.
EXPECT_EQ
(
beam
.
beam
()[
4
]
->
GetScore
(),
kOldScores
[
4
]
+
kTransitionMatrix
[
4
*
kNumTransitions
+
2
]);
EXPECT_TRUE
(
beam
.
beam
()[
4
]
->
IsGold
());
EXPECT_EQ
(
beam
.
beam
()[
5
]
->
GetScore
(),
kOldScores
[
3
]
+
kTransitionMatrix
[
3
*
kNumTransitions
+
2
]);
EXPECT_FALSE
(
beam
.
beam
()[
5
]
->
IsGold
());
EXPECT_EQ
(
beam
.
beam
()[
6
]
->
GetScore
(),
kOldScores
[
2
]
+
kTransitionMatrix
[
2
*
kNumTransitions
+
2
]);
EXPECT_FALSE
(
beam
.
beam
()[
6
]
->
IsGold
());
EXPECT_EQ
(
beam
.
beam
()[
7
]
->
GetScore
(),
kOldScores
[
1
]
+
kTransitionMatrix
[
1
*
kNumTransitions
+
2
]);
EXPECT_FALSE
(
beam
.
beam
()[
7
]
->
IsGold
());
// Validate that the beam still has a gold state in it.
EXPECT_TRUE
(
beam
.
ContainsGold
());
}
TEST
(
BeamTest
,
AdvancingDropsLowValuePredictions
)
{
// Create a matrix of transitions.
constexpr
int
kNumTransitions
=
4
;
constexpr
int
kMaxBeamSize
=
8
;
constexpr
int
kMatrixSize
=
kNumTransitions
*
kMaxBeamSize
;
constexpr
float
matrix
[
kMatrixSize
]
=
{
30.0
,
20.0
,
40.0
,
10.0
,
// State 0
31.0
,
21.0
,
41.0
,
11.0
,
// State 1
32.0
,
22.0
,
42.0
,
12.0
,
// State 2
33.0
,
23.0
,
43.0
,
13.0
,
// State 3
34.0
,
24.0
,
44.0
,
14.0
,
// State 4
35.0
,
25.0
,
45.0
,
15.0
,
// State 5
36.0
,
26.0
,
46.0
,
16.0
,
// State 6
37.0
,
27.0
,
47.0
,
17.0
};
// State 7
constexpr
float
kTransitionMatrix
[
kMatrixSize
]
=
{
30.0
,
20.0
,
40.0
,
10.0
,
// State 0
31.0
,
21.0
,
41.0
,
11.0
,
// State 1
32.0
,
22.0
,
42.0
,
12.0
,
// State 2
33.0
,
23.0
,
43.0
,
13.0
,
// State 3
34.0
,
24.0
,
44.0
,
14.0
,
// State 4
35.0
,
25.0
,
45.0
,
15.0
,
// State 5
36.0
,
26.0
,
46.0
,
16.0
,
// State 6
37.0
,
27.0
,
47.0
,
17.0
};
// State 7
constexpr
float
kOldScores
[]
=
{
0.1
,
0.2
,
0.3
,
0.4
,
0.5
,
0.6
,
0.7
,
0.8
};
// Create the beam and transition it.
...
...
@@ -302,7 +765,7 @@ TEST(BeamTest, AdvancingDropsLowValuePredictions) {
beam
.
SetFunctions
(
null_permissions
,
null_finality
,
transition_function
,
null_oracle
);
beam
.
Init
(
std
::
move
(
states
));
beam
.
AdvanceFromPrediction
(
m
atrix
,
kMatrixSize
,
kNumTransitions
);
beam
.
AdvanceFromPrediction
(
kTransitionM
atrix
,
kMatrixSize
,
kNumTransitions
);
// Validate the new beam.
EXPECT_EQ
(
beam
.
beam
().
size
(),
8
);
...
...
@@ -323,21 +786,21 @@ TEST(BeamTest, AdvancingDropsLowValuePredictions) {
// beam, and so forth.) For the matrix index, N*kNumTransitions gets into the
// correct state row and we add 2 since that was the transition index.
EXPECT_EQ
(
beam
.
beam
().
at
(
0
)
->
GetScore
(),
kOldScores
[
7
]
+
m
atrix
[
7
*
kNumTransitions
+
2
]);
kOldScores
[
7
]
+
kTransitionM
atrix
[
7
*
kNumTransitions
+
2
]);
EXPECT_EQ
(
beam
.
beam
().
at
(
1
)
->
GetScore
(),
kOldScores
[
6
]
+
m
atrix
[
6
*
kNumTransitions
+
2
]);
kOldScores
[
6
]
+
kTransitionM
atrix
[
6
*
kNumTransitions
+
2
]);
EXPECT_EQ
(
beam
.
beam
().
at
(
2
)
->
GetScore
(),
kOldScores
[
5
]
+
m
atrix
[
5
*
kNumTransitions
+
2
]);
kOldScores
[
5
]
+
kTransitionM
atrix
[
5
*
kNumTransitions
+
2
]);
EXPECT_EQ
(
beam
.
beam
().
at
(
3
)
->
GetScore
(),
kOldScores
[
4
]
+
m
atrix
[
4
*
kNumTransitions
+
2
]);
kOldScores
[
4
]
+
kTransitionM
atrix
[
4
*
kNumTransitions
+
2
]);
EXPECT_EQ
(
beam
.
beam
().
at
(
4
)
->
GetScore
(),
kOldScores
[
3
]
+
m
atrix
[
3
*
kNumTransitions
+
2
]);
kOldScores
[
3
]
+
kTransitionM
atrix
[
3
*
kNumTransitions
+
2
]);
EXPECT_EQ
(
beam
.
beam
().
at
(
5
)
->
GetScore
(),
kOldScores
[
2
]
+
m
atrix
[
2
*
kNumTransitions
+
2
]);
kOldScores
[
2
]
+
kTransitionM
atrix
[
2
*
kNumTransitions
+
2
]);
EXPECT_EQ
(
beam
.
beam
().
at
(
6
)
->
GetScore
(),
kOldScores
[
1
]
+
m
atrix
[
1
*
kNumTransitions
+
2
]);
kOldScores
[
1
]
+
kTransitionM
atrix
[
1
*
kNumTransitions
+
2
]);
EXPECT_EQ
(
beam
.
beam
().
at
(
7
)
->
GetScore
(),
kOldScores
[
0
]
+
m
atrix
[
0
*
kNumTransitions
+
2
]);
kOldScores
[
0
]
+
kTransitionM
atrix
[
0
*
kNumTransitions
+
2
]);
// Make sure that the beam index field is consistent with the actual beam idx.
for
(
int
i
=
0
;
i
<
beam
.
beam
().
size
();
++
i
)
{
...
...
@@ -358,7 +821,9 @@ TEST(BeamTest, AdvancingDropsLowValuePredictions) {
TEST
(
BeamTest
,
AdvancesFromOracleWithSingleBeam
)
{
// Create an oracle function for this state.
constexpr
int
kOracleLabel
=
3
;
auto
oracle_function
=
[](
TransitionState
*
)
{
return
kOracleLabel
;
};
auto
oracle_function
=
[](
TransitionState
*
)
->
const
vector
<
int
>
{
return
{
kOracleLabel
};
};
// Create the beam and transition it.
std
::
vector
<
std
::
unique_ptr
<
TestTransitionState
>>
states
;
...
...
@@ -392,21 +857,24 @@ TEST(BeamTest, AdvancesFromOracleWithMultipleStates) {
// Create a beam with 8 transition states.
std
::
vector
<
std
::
unique_ptr
<
TestTransitionState
>>
states
;
states
.
reserve
(
kMaxBeamSize
);
for
(
int
i
=
0
;
i
<
kMaxBeamSize
;
++
i
)
{
// This is nonzero to test the oracle holding scores
to 0
.
// This is nonzero to test the oracle holding scores
constant
.
states
.
push_back
(
CreateState
(
10.0
));
}
std
::
vector
<
int
>
expected_actions
;
// Create an oracle function for this state. Use mocks for finer control.
testing
::
MockFunction
<
int
(
TestTransitionState
*
)
>
mock_oracle_function
;
testing
::
MockFunction
<
const
vector
<
int
>
(
TestTransitionState
*
)
>
mock_oracle_function
;
for
(
int
i
=
0
;
i
<
kMaxBeamSize
;
++
i
)
{
// We expect each state to be queried for its oracle label,
// and then to be transitioned in place with its oracle label.
int
oracle_label
=
i
%
3
;
// 3 is arbitrary.
vector
<
int
>
oracle_labels
=
{
oracle_label
};
EXPECT_CALL
(
mock_oracle_function
,
Call
(
states
.
at
(
i
).
get
()))
.
WillOnce
(
Return
(
oracle_label
));
.
WillOnce
(
Return
(
oracle_label
s
));
expected_actions
.
push_back
(
oracle_label
);
}
...
...
@@ -435,6 +903,7 @@ TEST(BeamTest, ReportsNonFinality) {
// Create a beam with 8 transition states.
std
::
vector
<
std
::
unique_ptr
<
TestTransitionState
>>
states
;
states
.
reserve
(
kMaxBeamSize
);
for
(
int
i
=
0
;
i
<
kMaxBeamSize
;
++
i
)
{
// This is nonzero to test the oracle holding scores to 0.
states
.
push_back
(
CreateState
(
10.0
));
...
...
@@ -467,6 +936,7 @@ TEST(BeamTest, ReportsFinality) {
// Create a beam with 8 transition states.
std
::
vector
<
std
::
unique_ptr
<
TestTransitionState
>>
states
;
states
.
reserve
(
kMaxBeamSize
);
for
(
int
i
=
0
;
i
<
kMaxBeamSize
;
++
i
)
{
// This is nonzero to test the oracle holding scores to 0.
states
.
push_back
(
CreateState
(
10.0
));
...
...
@@ -493,7 +963,7 @@ TEST(BeamTest, IgnoresForbiddenTransitionActions) {
constexpr
int
kMaxBeamSize
=
4
;
constexpr
int
kNumTransitions
=
4
;
constexpr
int
kMatrixSize
=
kNumTransitions
*
kMaxBeamSize
;
constexpr
float
m
atrix
[
kMatrixSize
]
=
{
constexpr
float
kTransitionM
atrix
[
kMatrixSize
]
=
{
10.0
,
1000.0
,
40.0
,
30.0
,
00.0
,
0000.0
,
00.0
,
00.0
,
00.0
,
0000.0
,
00.0
,
00.0
,
00.0
,
0000.0
,
00.0
,
00.0
};
constexpr
float
kOldScore
=
4.0
;
...
...
@@ -518,7 +988,7 @@ TEST(BeamTest, IgnoresForbiddenTransitionActions) {
beam
.
SetFunctions
(
mock_permission_function
.
AsStdFunction
(),
null_finality
,
transition_function
,
null_oracle
);
beam
.
Init
(
std
::
move
(
states
));
beam
.
AdvanceFromPrediction
(
m
atrix
,
kMatrixSize
,
kNumTransitions
);
beam
.
AdvanceFromPrediction
(
kTransitionM
atrix
,
kMatrixSize
,
kNumTransitions
);
// Validate the new beam.
EXPECT_EQ
(
beam
.
beam
().
size
(),
3
);
...
...
@@ -529,9 +999,9 @@ TEST(BeamTest, IgnoresForbiddenTransitionActions) {
EXPECT_EQ
(
GetTransition
(
beam
.
beam
().
at
(
2
)),
0
);
// Make sure the state has had its score updated properly.
EXPECT_EQ
(
beam
.
beam
().
at
(
0
)
->
GetScore
(),
kOldScore
+
m
atrix
[
2
]);
EXPECT_EQ
(
beam
.
beam
().
at
(
1
)
->
GetScore
(),
kOldScore
+
m
atrix
[
3
]);
EXPECT_EQ
(
beam
.
beam
().
at
(
2
)
->
GetScore
(),
kOldScore
+
m
atrix
[
0
]);
EXPECT_EQ
(
beam
.
beam
().
at
(
0
)
->
GetScore
(),
kOldScore
+
kTransitionM
atrix
[
2
]);
EXPECT_EQ
(
beam
.
beam
().
at
(
1
)
->
GetScore
(),
kOldScore
+
kTransitionM
atrix
[
3
]);
EXPECT_EQ
(
beam
.
beam
().
at
(
2
)
->
GetScore
(),
kOldScore
+
kTransitionM
atrix
[
0
]);
// Make sure that the beam index field is consistent with the actual beam idx.
for
(
int
i
=
0
;
i
<
beam
.
beam
().
size
();
++
i
)
{
...
...
@@ -551,7 +1021,7 @@ TEST(BeamTest, BadlySizedMatrixDies) {
// Create a matrix of transitions.
constexpr
int
kNumTransitions
=
4
;
constexpr
int
kMatrixSize
=
4
;
// We have a max beam size of 4; should be 16.
constexpr
float
m
atrix
[
kMatrixSize
]
=
{
30.0
,
20.0
,
40.0
,
10.0
};
constexpr
float
kTransitionM
atrix
[
kMatrixSize
]
=
{
30.0
,
20.0
,
40.0
,
10.0
};
// Create the beam and transition it.
std
::
vector
<
std
::
unique_ptr
<
TestTransitionState
>>
states
;
...
...
@@ -564,7 +1034,8 @@ TEST(BeamTest, BadlySizedMatrixDies) {
beam
.
Init
(
std
::
move
(
states
));
// This matrix should have 8 elements, not 4, so this should die.
EXPECT_DEATH
(
beam
.
AdvanceFromPrediction
(
matrix
,
kMatrixSize
,
kNumTransitions
),
EXPECT_DEATH
(
beam
.
AdvanceFromPrediction
(
kTransitionMatrix
,
kMatrixSize
,
kNumTransitions
),
"Transition matrix size does not match max beam size
\\
* number "
"of state transitions"
);
}
...
...
@@ -573,6 +1044,7 @@ TEST(BeamTest, BadlySizedBeamInitializationDies) {
// Create an initialization beam too large for the max beam size.
constexpr
int
kMaxBeamSize
=
4
;
std
::
vector
<
std
::
unique_ptr
<
TestTransitionState
>>
states
;
states
.
reserve
(
kMaxBeamSize
+
1
);
for
(
int
i
=
0
;
i
<
kMaxBeamSize
+
1
;
++
i
)
{
states
.
push_back
(
CreateState
(
0.0
));
}
...
...
@@ -590,6 +1062,7 @@ TEST(BeamTest, ValidBeamIndicesAfterBeamInitialization) {
// Create a standard beam.
constexpr
int
kMaxBeamSize
=
4
;
std
::
vector
<
std
::
unique_ptr
<
TestTransitionState
>>
states
;
states
.
reserve
(
kMaxBeamSize
);
for
(
int
i
=
0
;
i
<
kMaxBeamSize
;
++
i
)
{
states
.
push_back
(
CreateState
(
0.0
));
}
...
...
@@ -611,7 +1084,7 @@ TEST(BeamTest, FindPreviousIndexTracesHistory) {
constexpr
int
kNumTransitions
=
4
;
constexpr
int
kMaxBeamSize
=
8
;
constexpr
int
kMatrixSize
=
kNumTransitions
*
kMaxBeamSize
;
constexpr
float
m
atrix
[
kMatrixSize
]
=
{
constexpr
float
kTransitionM
atrix
[
kMatrixSize
]
=
{
30.0
,
20.0
,
40.0
,
10.0
,
// State 0
31.0
,
21.0
,
41.0
,
11.0
,
// State 1
00.0
,
00.0
,
00.0
,
00.0
,
00.0
,
00.0
,
00.0
,
00.0
,
00.0
,
00.0
,
00.0
,
00.0
,
...
...
@@ -632,7 +1105,7 @@ TEST(BeamTest, FindPreviousIndexTracesHistory) {
beam
.
SetFunctions
(
null_permissions
,
null_finality
,
transition_function
,
null_oracle
);
beam
.
Init
(
std
::
move
(
states
));
beam
.
AdvanceFromPrediction
(
m
atrix
,
kMatrixSize
,
kNumTransitions
);
beam
.
AdvanceFromPrediction
(
kTransitionM
atrix
,
kMatrixSize
,
kNumTransitions
);
// Validate the new beam.
EXPECT_EQ
(
beam
.
beam
().
size
(),
8
);
...
...
@@ -650,14 +1123,22 @@ TEST(BeamTest, FindPreviousIndexTracesHistory) {
EXPECT_EQ
(
GetTransition
(
beam
.
beam
().
at
(
7
)),
3
);
// Make sure the state has had its score updated properly.
EXPECT_EQ
(
beam
.
beam
().
at
(
0
)
->
GetScore
(),
kOldScores
[
1
]
+
matrix
[
6
]);
EXPECT_EQ
(
beam
.
beam
().
at
(
1
)
->
GetScore
(),
kOldScores
[
0
]
+
matrix
[
2
]);
EXPECT_EQ
(
beam
.
beam
().
at
(
2
)
->
GetScore
(),
kOldScores
[
1
]
+
matrix
[
4
]);
EXPECT_EQ
(
beam
.
beam
().
at
(
3
)
->
GetScore
(),
kOldScores
[
0
]
+
matrix
[
0
]);
EXPECT_EQ
(
beam
.
beam
().
at
(
4
)
->
GetScore
(),
kOldScores
[
1
]
+
matrix
[
5
]);
EXPECT_EQ
(
beam
.
beam
().
at
(
5
)
->
GetScore
(),
kOldScores
[
0
]
+
matrix
[
1
]);
EXPECT_EQ
(
beam
.
beam
().
at
(
6
)
->
GetScore
(),
kOldScores
[
1
]
+
matrix
[
7
]);
EXPECT_EQ
(
beam
.
beam
().
at
(
7
)
->
GetScore
(),
kOldScores
[
0
]
+
matrix
[
3
]);
EXPECT_EQ
(
beam
.
beam
().
at
(
0
)
->
GetScore
(),
kOldScores
[
1
]
+
kTransitionMatrix
[
6
]);
EXPECT_EQ
(
beam
.
beam
().
at
(
1
)
->
GetScore
(),
kOldScores
[
0
]
+
kTransitionMatrix
[
2
]);
EXPECT_EQ
(
beam
.
beam
().
at
(
2
)
->
GetScore
(),
kOldScores
[
1
]
+
kTransitionMatrix
[
4
]);
EXPECT_EQ
(
beam
.
beam
().
at
(
3
)
->
GetScore
(),
kOldScores
[
0
]
+
kTransitionMatrix
[
0
]);
EXPECT_EQ
(
beam
.
beam
().
at
(
4
)
->
GetScore
(),
kOldScores
[
1
]
+
kTransitionMatrix
[
5
]);
EXPECT_EQ
(
beam
.
beam
().
at
(
5
)
->
GetScore
(),
kOldScores
[
0
]
+
kTransitionMatrix
[
1
]);
EXPECT_EQ
(
beam
.
beam
().
at
(
6
)
->
GetScore
(),
kOldScores
[
1
]
+
kTransitionMatrix
[
7
]);
EXPECT_EQ
(
beam
.
beam
().
at
(
7
)
->
GetScore
(),
kOldScores
[
0
]
+
kTransitionMatrix
[
3
]);
// Make sure that the beam index field is consistent with the actual beam idx.
for
(
int
i
=
0
;
i
<
beam
.
beam
().
size
();
++
i
)
{
...
...
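Read together, these hunks show the beam moving from AdvanceFromPrediction(matrix, matrix_length) to a three-argument call whose score array holds one row of kNumTransitions scores per beam state. A minimal caller-side sketch of the new call, reusing helpers from this test file (the Beam constructor argument is an assumption, not taken from this diff):

// Sketch only: drive a beam with the three-argument AdvanceFromPrediction.
constexpr int kNumTransitions = 4;  // actions scored per state
constexpr int kMaxBeamSize = 2;     // states kept in the beam
constexpr int kMatrixSize = kNumTransitions * kMaxBeamSize;
constexpr float kTransitionMatrix[kMatrixSize] = {
    30.0, 20.0, 40.0, 10.0,  // scores for state 0
    31.0, 21.0, 41.0, 11.0,  // scores for state 1
};

Beam<TestTransitionState> beam(kMaxBeamSize);  // assumed constructor
beam.SetFunctions(null_permissions, null_finality, transition_function,
                  null_oracle);
std::vector<std::unique_ptr<TestTransitionState>> states;
states.push_back(CreateState(0.0));
beam.Init(std::move(states));

// New signature: flat score array, its total size, and actions per state.
// A mismatched size CHECK-fails (see BadlySizedMatrixDies above).
beam.AdvanceFromPrediction(kTransitionMatrix, kMatrixSize, kNumTransitions);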
research/syntaxnet/dragnn/core/component_registry.h  View file @ 4364390a
...
@@ -13,12 +13,17 @@
 // limitations under the License.
 // =============================================================================

-#ifndef NLP_SAFT_OPENSOURCE_DRAGNN_CORE_COMPONENT_REGISTRY_H_
-#define NLP_SAFT_OPENSOURCE_DRAGNN_CORE_COMPONENT_REGISTRY_H_
+#ifndef DRAGNN_CORE_COMPONENT_REGISTRY_H_
+#define DRAGNN_CORE_COMPONENT_REGISTRY_H_

 #include "dragnn/core/interfaces/component.h"
 #include "syntaxnet/registry.h"

 namespace syntaxnet {

 // Class registry for DRAGNN components.
 DECLARE_SYNTAXNET_CLASS_REGISTRY("DRAGNN Component", dragnn::Component);

 }  // namespace syntaxnet

 // Macro to add a component to the registry. This macro associates a class with
 // its class name as a string, so FooComponent would be associated with the
 // string "FooComponent".
...
@@ -26,4 +31,4 @@
   REGISTER_SYNTAXNET_CLASS_COMPONENT(syntaxnet::dragnn::Component, #component, \
                                      component)

-#endif  // NLP_SAFT_OPENSOURCE_DRAGNN_CORE_COMPONENT_REGISTRY_H_
+#endif  // DRAGNN_CORE_COMPONENT_REGISTRY_H_
research/syntaxnet/dragnn/core/compute_session.h  View file @ 4364390a
...
@@ -13,13 +13,14 @@
 // limitations under the License.
 // =============================================================================

-#ifndef NLP_SAFT_OPENSOURCE_DRAGNN_CORE_COMPUTE_SESSION_H_
-#define NLP_SAFT_OPENSOURCE_DRAGNN_CORE_COMPUTE_SESSION_H_
+#ifndef DRAGNN_CORE_COMPUTE_SESSION_H_
+#define DRAGNN_CORE_COMPUTE_SESSION_H_

 #include <string>

 #include "dragnn/components/util/bulk_feature_extractor.h"
 #include "dragnn/core/index_translator.h"
+#include "dragnn/core/input_batch_cache.h"
 #include "dragnn/core/interfaces/component.h"
 #include "dragnn/protos/spec.pb.h"
 #include "dragnn/protos/trace.pb.h"
...
@@ -64,10 +65,11 @@ class ComputeSession {
// Advance the given component using the component's oracle.
virtual
void
AdvanceFromOracle
(
const
string
&
component_name
)
=
0
;
// Advance the given component using the given score matrix.
virtual
void
AdvanceFromPrediction
(
const
string
&
component_name
,
const
float
score_matrix
[],
int
score_matrix_length
)
=
0
;
// Advance the given component using the given score matrix, which is
// |num_items| x |num_actions|.
virtual
bool
AdvanceFromPrediction
(
const
string
&
component_name
,
const
float
*
score_matrix
,
int
num_items
,
int
num_actions
)
=
0
;
// Get the input features for the given component and channel. This passes
// through to the relevant Component's GetFixedFeatures() call.
...
...
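In other words, callers now pass an explicit |num_items| x |num_actions| layout and check a bool result instead of relying on a void call. A hedged usage sketch (the component name and score values are placeholders):

// Sketch only: advancing one component through the new ComputeSession API.
constexpr int kNumItems = 2;    // batch items (times beam size)
constexpr int kNumActions = 3;  // scores per item
const float scores[kNumItems * kNumActions] = {0.1, 0.7, 0.2,
                                               0.4, 0.5, 0.1};
if (!session->AdvanceFromPrediction("my_component", scores, kNumItems,
                                    kNumActions)) {
  LOG(ERROR) << "AdvanceFromPrediction failed; check the score matrix shape.";
}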
@@ -84,6 +86,15 @@ class ComputeSession {
virtual
int
BulkGetInputFeatures
(
const
string
&
component_name
,
const
BulkFeatureExtractor
&
extractor
)
=
0
;
// Directly computes the embedding matrix for all channels, advancing the
// component via the oracle until it is terminal. This call takes a vector
// of float embedding matrices, one per channel, in channel order.
virtual
void
BulkEmbedFixedFeatures
(
const
string
&
component_name
,
int
batch_size_padding
,
int
num_steps_padding
,
int
output_array_size
,
const
vector
<
const
float
*>
&
per_channel_embeddings
,
float
*
embedding_output
)
=
0
;
// Get the input features for the given component and channel. This function
// can return empty LinkFeatures protos, which represent unused padding slots
// in the output weight tensor.
...
...
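The caller owns the output buffer, so it has to be sized up front; the BulkEmbedFixedFeatures kernel later in this commit derives the size from the padding values, the beam size, and the per-channel embedding widths. A rough sizing sketch under those assumptions (embedding_dim and feature_count are hypothetical helpers, not part of this API):

// Sketch only: sizing the embedding_output buffer before the call.
int embedding_size = 0;
for (int channel = 0; channel < num_channels; ++channel) {
  // embedding_dim(channel): columns of that channel's embedding matrix.
  // feature_count(channel): ComponentSpec.fixed_feature(channel).size().
  embedding_size += embedding_dim(channel) * feature_count(channel);
}
const int num_rows = batch_size_padding * num_steps_padding * beam_size;
std::vector<float> embedding_output(num_rows * embedding_size, 0.0f);
session->BulkEmbedFixedFeatures(component_name, batch_size_padding,
                                num_steps_padding, embedding_output.size(),
                                per_channel_embeddings,
                                embedding_output.data());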
@@ -111,6 +122,10 @@ class ComputeSession {
// Provides the ComputeSession with a batch of data to compute.
virtual
void
SetInputData
(
const
std
::
vector
<
string
>
&
data
)
=
0
;
// Like SetInputData(), but accepts an InputBatchCache directly, potentially
// bypassing de-serialization.
virtual
void
SetInputBatchCache
(
std
::
unique_ptr
<
InputBatchCache
>
batch
)
=
0
;
// Resets all components owned by this ComputeSession.
virtual
void
ResetSession
()
=
0
;
...
...
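A hedged sketch of the intended use: build the batch object once, wrap it in an InputBatchCache, and hand it to the session so no string parsing happens inside the session (SentenceBatch stands in for any strict InputBatch subclass; it is not defined in this commit):

// Sketch only: feed pre-built data to a session via SetInputBatchCache().
std::unique_ptr<SentenceBatch> batch(new SentenceBatch());  // hypothetical type
batch->SetData({"serialized item 1", "serialized item 2"});
std::unique_ptr<InputBatchCache> cache(new InputBatchCache(std::move(batch)));
session->SetInputBatchCache(std::move(cache));
// Components can now call GetAs<SentenceBatch>() without re-parsing strings.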
@@ -127,9 +142,14 @@ class ComputeSession {
   // validate correct construction of translators in tests.
   virtual const std::vector<const IndexTranslator *> Translators(
       const string &component_name) const = 0;

+  // Get a given component. CHECK-fail if the component's IsReady method
+  // returns false.
+  virtual Component *GetReadiedComponent(
+      const string &component_name) const = 0;
+
 };

 }  // namespace dragnn
 }  // namespace syntaxnet

-#endif  // NLP_SAFT_OPENSOURCE_DRAGNN_CORE_COMPUTE_SESSION_H_
+#endif  // DRAGNN_CORE_COMPUTE_SESSION_H_
research/syntaxnet/dragnn/core/compute_session_impl.cc  View file @ 4364390a
...
@@ -161,11 +161,11 @@ void ComputeSessionImpl::AdvanceFromOracle(const string &component_name) {
   GetReadiedComponent(component_name)->AdvanceFromOracle();
 }

-void ComputeSessionImpl::AdvanceFromPrediction(const string &component_name,
-                                               const float score_matrix[],
-                                               int score_matrix_length) {
-  GetReadiedComponent(component_name)
-      ->AdvanceFromPrediction(score_matrix, score_matrix_length);
+bool ComputeSessionImpl::AdvanceFromPrediction(const string &component_name,
+                                               const float *score_matrix,
+                                               int num_items, int num_actions) {
+  return GetReadiedComponent(component_name)
+      ->AdvanceFromPrediction(score_matrix, num_items, num_actions);
 }

 int ComputeSessionImpl::GetInputFeatures(
...
@@ -182,6 +182,16 @@ int ComputeSessionImpl::BulkGetInputFeatures(
   return GetReadiedComponent(component_name)->BulkGetFixedFeatures(extractor);
 }

+void ComputeSessionImpl::BulkEmbedFixedFeatures(
+    const string &component_name, int batch_size_padding, int num_steps_padding,
+    int output_array_size, const vector<const float *> &per_channel_embeddings,
+    float *embedding_output) {
+  return GetReadiedComponent(component_name)
+      ->BulkEmbedFixedFeatures(batch_size_padding, num_steps_padding,
+                               output_array_size, per_channel_embeddings,
+                               embedding_output);
+}
+
 std::vector<LinkFeatures> ComputeSessionImpl::GetTranslatedLinkFeatures(
     const string &component_name, int channel_id) {
   auto *component = GetReadiedComponent(component_name);
...
@@ -288,6 +298,11 @@ void ComputeSessionImpl::SetInputData(const std::vector<string> &data) {
   input_data_.reset(new InputBatchCache(data));
 }

+void ComputeSessionImpl::SetInputBatchCache(
+    std::unique_ptr<InputBatchCache> batch) {
+  input_data_ = std::move(batch);
+}
+
 void ComputeSessionImpl::ResetSession() {
   // Reset all component states.
   for (auto &component_pair : components_) {
...
@@ -308,6 +323,7 @@ const std::vector<const IndexTranslator *> ComputeSessionImpl::Translators(
     const string &component_name) const {
   auto translators = GetTranslators(component_name);
   std::vector<const IndexTranslator *> const_translators;
+  const_translators.reserve(translators.size());
   for (const auto &translator : translators) {
     const_translators.push_back(translator);
   }
...
research/syntaxnet/dragnn/core/compute_session_impl.h  View file @ 4364390a
...
@@ -13,8 +13,8 @@
 // limitations under the License.
 // =============================================================================

-#ifndef NLP_SAFT_OPENSOURCE_DRAGNN_CORE_COMPUTE_SESSION_IMPL_H_
-#define NLP_SAFT_OPENSOURCE_DRAGNN_CORE_COMPUTE_SESSION_IMPL_H_
+#ifndef DRAGNN_CORE_COMPUTE_SESSION_IMPL_H_
+#define DRAGNN_CORE_COMPUTE_SESSION_IMPL_H_

 #include <memory>
...
@@ -55,9 +55,9 @@ class ComputeSessionImpl : public ComputeSession {
   void AdvanceFromOracle(const string &component_name) override;

-  void AdvanceFromPrediction(const string &component_name,
-                             const float score_matrix[],
-                             int score_matrix_length) override;
+  bool AdvanceFromPrediction(const string &component_name,
+                             const float *score_matrix, int num_items,
+                             int num_actions) override;

   int GetInputFeatures(
       const string &component_name,
       std::function<int32 *(int)> allocate_indices,
...
@@ -68,6 +68,12 @@ class ComputeSessionImpl : public ComputeSession {
int
BulkGetInputFeatures
(
const
string
&
component_name
,
const
BulkFeatureExtractor
&
extractor
)
override
;
void
BulkEmbedFixedFeatures
(
const
string
&
component_name
,
int
batch_size_padding
,
int
num_steps_padding
,
int
output_array_size
,
const
vector
<
const
float
*>
&
per_channel_embeddings
,
float
*
embedding_output
)
override
;
std
::
vector
<
LinkFeatures
>
GetTranslatedLinkFeatures
(
const
string
&
component_name
,
int
channel_id
)
override
;
...
...
@@ -84,6 +90,8 @@ class ComputeSessionImpl : public ComputeSession {
   void SetInputData(const std::vector<string> &data) override;

+  void SetInputBatchCache(std::unique_ptr<InputBatchCache> batch) override;
+
   void ResetSession() override;

   void SetTracing(bool tracing_on) override;
...
@@ -95,14 +103,14 @@ class ComputeSessionImpl : public ComputeSession {
   const std::vector<const IndexTranslator *> Translators(
       const string &component_name) const override;

+  // Get a given component. CHECK-fail if the component's IsReady method
+  // returns false.
+  Component *GetReadiedComponent(const string &component_name) const override;
+
  private:
   // Get a given component. Fails if the component is not found.
   Component *GetComponent(const string &component_name) const;

-  // Get a given component. CHECK-fail if the component's IsReady method
-  // returns false.
-  Component *GetReadiedComponent(const string &component_name) const;
-
   // Get the index translators for the given component.
   const std::vector<IndexTranslator *> &GetTranslators(
       const string &component_name) const;
...
@@ -154,4 +162,4 @@ class ComputeSessionImpl : public ComputeSession {
}
// namespace dragnn
}
// namespace syntaxnet
#endif //
NLP_SAFT_OPENSOURCE_
DRAGNN_CORE_COMPUTE_SESSION_IMPL_H_
#endif // DRAGNN_CORE_COMPUTE_SESSION_IMPL_H_
research/syntaxnet/dragnn/core/compute_session_impl_test.cc  View file @ 4364390a
...
@@ -22,7 +22,9 @@
 #include "dragnn/core/component_registry.h"
 #include "dragnn/core/compute_session.h"
 #include "dragnn/core/compute_session_pool.h"
+#include "dragnn/core/input_batch_cache.h"
 #include "dragnn/core/interfaces/component.h"
+#include "dragnn/core/interfaces/input_batch.h"
 #include "dragnn/core/test/generic.h"
 #include "dragnn/core/test/mock_component.h"
 #include "dragnn/core/test/mock_transition_state.h"
...
@@ -65,8 +67,10 @@ class TestComponentType1 : public Component {
   int GetSourceBeamIndex(int current_index, int batch) const override {
     return 0;
   }

-  void AdvanceFromPrediction(const float transition_matrix[],
-                             int matrix_length) override {}
+  bool AdvanceFromPrediction(const float *score_matrix, int num_items,
+                             int num_actions) override {
+    return true;
+  }

   void AdvanceFromOracle() override {}
   bool IsTerminal() const override { return true; }
   std::function<int(int, int, int)> GetStepLookupFunction(
...
int
channel_id
)
const
override
{
return
0
;
}
void
BulkEmbedFixedFeatures
(
int
batch_size_padding
,
int
num_steps_padding
,
int
embedding_size
,
const
vector
<
const
float
*>
&
per_channel_embeddings
,
float
*
embedding_output
)
override
{}
int
BulkGetFixedFeatures
(
const
BulkFeatureExtractor
&
extractor
)
override
{
return
0
;
}
...
...
@@ -133,8 +141,10 @@ class TestComponentType2 : public Component {
   int GetSourceBeamIndex(int current_index, int batch) const override {
     return 0;
   }

-  void AdvanceFromPrediction(const float transition_matrix[],
-                             int matrix_length) override {}
+  bool AdvanceFromPrediction(const float *score_matrix, int num_items,
+                             int num_actions) override {
+    return true;
+  }

   void AdvanceFromOracle() override {}
   bool IsTerminal() const override { return true; }
   std::function<int(int, int, int)> GetStepLookupFunction(
...
@@ -151,6 +161,10 @@ class TestComponentType2 : public Component {
                        int channel_id) const override {
     return 0;
   }

+  void BulkEmbedFixedFeatures(
+      int batch_size_padding, int num_steps_padding, int embedding_size,
+      const vector<const float *> &per_channel_embeddings,
+      float *embedding_output) override {}
+
   int BulkGetFixedFeatures(const BulkFeatureExtractor &extractor) override {
     return 0;
   }
...
@@ -201,8 +215,14 @@ class UnreadyComponent : public Component {
   int GetSourceBeamIndex(int current_index, int batch) const override {
     return 0;
   }

-  void AdvanceFromPrediction(const float transition_matrix[],
-                             int matrix_length) override {}
+  bool AdvanceFromPrediction(const float *score_matrix, int num_items,
+                             int num_actions) override {
+    return true;
+  }
+
+  void BulkEmbedFixedFeatures(
+      int batch_size_padding, int num_steps_padding, int embedding_size,
+      const vector<const float *> &per_channel_embeddings,
+      float *embedding_output) override {}

   void AdvanceFromOracle() override {}
   bool IsTerminal() const override { return false; }
   std::function<int(int, int, int)> GetStepLookupFunction(
...
@@ -254,6 +274,18 @@ class ComputeSessionImplTestPoolAccessor {
   }
 };

+// An InputBatch that uses the serialized data directly.
+class IdentityBatch : public InputBatch {
+ public:
+  // Implements InputBatch.
+  void SetData(const std::vector<string> &data) override { data_ = data; }
+  int GetSize() const override { return data_.size(); }
+  const std::vector<string> GetSerializedData() const override { return data_; }
+
+ private:
+  std::vector<string> data_;  // the batch data
+};
+
 // *****************************************************************************
 // Tests begin here.
 // *****************************************************************************
...
@@ -739,7 +771,7 @@ TEST(ComputeSessionImplTest, InitializesComponentWithSource) {
   EXPECT_CALL(*mock_components["component_one"], GetBeam())
       .WillOnce(Return(beam));

-  // Expect that the second component will recieve that beam.
+  // Expect that the second component will receive that beam.
   EXPECT_CALL(*mock_components["component_two"],
               InitializeData(beam, kMaxBeamSize, NotNull()));
...
@@ -899,7 +931,7 @@ TEST(ComputeSessionImplTest, SetTracingPropagatesToAllComponents) {
   EXPECT_CALL(*mock_components["component_one"], GetBeam())
       .WillOnce(Return(beam));

-  // Expect that the second component will recieve that beam, and then its
+  // Expect that the second component will receive that beam, and then its
   // tracing will be initialized.
   EXPECT_CALL(*mock_components["component_two"],
               InitializeData(beam, kMaxBeamSize, NotNull()));
...
@@ -1084,12 +1116,12 @@ TEST(ComputeSessionImplTest, InterfacePassesThrough) {
   session->AdvanceFromOracle("component_one");

   // AdvanceFromPrediction()
-  constexpr int kScoreMatrixLength = 3;
-  const float score_matrix[kScoreMatrixLength] = {1.0, 2.3, 4.5};
+  const int kNumActions = 1;
+  const float score_matrix[] = {1.0, 2.3, 4.5};
   EXPECT_CALL(*mock_components["component_one"],
-              AdvanceFromPrediction(score_matrix, kScoreMatrixLength));
-  session->AdvanceFromPrediction("component_one", score_matrix,
-                                 kScoreMatrixLength);
+              AdvanceFromPrediction(score_matrix, batch_size, kNumActions));
+  session->AdvanceFromPrediction("component_one", score_matrix, batch_size,
+                                 kNumActions);

   // GetFixedFeatures
   auto allocate_indices = [](int size) -> int32 * { return nullptr; };
...
@@ -1109,6 +1141,11 @@ TEST(ComputeSessionImplTest, InterfacePassesThrough) {
       .WillOnce(Return(0));
   EXPECT_EQ(0, session->BulkGetInputFeatures("component_one", extractor));

+  // BulkEmbedFixedFeatures
+  EXPECT_CALL(*mock_components["component_one"],
+              BulkEmbedFixedFeatures(1, 2, 3, _, _));
+  session->BulkEmbedFixedFeatures("component_one", 1, 2, 3, {nullptr}, nullptr);
+
   // EmitOracleLabels()
   std::vector<std::vector<int>> oracle_labels = {{0, 1}, {2, 3}};
   EXPECT_CALL(*mock_components["component_one"], GetOracleLabels())
...
@@ -1154,7 +1191,7 @@ TEST(ComputeSessionImplTest, InterfaceRequiresReady) {
   constexpr int kScoreMatrixLength = 3;
   const float score_matrix[kScoreMatrixLength] = {1.0, 2.3, 4.5};
   EXPECT_DEATH(session->AdvanceFromPrediction("component_one", score_matrix,
-                                              kScoreMatrixLength),
+                                              kScoreMatrixLength, 1),
               "without first initializing it");

   constexpr int kArbitraryChannelId = 3;
   EXPECT_DEATH(session->GetInputFeatures("component_one", nullptr, nullptr,
...
@@ -1163,10 +1200,32 @@ TEST(ComputeSessionImplTest, InterfaceRequiresReady) {
   BulkFeatureExtractor extractor(nullptr, nullptr, nullptr, false, 0, 0);
   EXPECT_DEATH(session->BulkGetInputFeatures("component_one", extractor),
               "without first initializing it");

+  EXPECT_DEATH(session->BulkEmbedFixedFeatures("component_one", 0, 0, 0,
+                                               {nullptr}, nullptr),
+               "without first initializing it");
+
   EXPECT_DEATH(
       session->GetTranslatedLinkFeatures("component_one", kArbitraryChannelId),
       "without first initializing it");
 }

+TEST(ComputeSessionImplTest, SetInputBatchCache) {
+  // Use empty protos since we won't interact with components.
+  MasterSpec spec;
+  GridPoint hyperparams;
+  ComputeSessionPool pool(spec, hyperparams);
+  auto session = pool.GetSession();
+
+  // Initialize a cached IdentityBatch.
+  const std::vector<string> data = {"foo", "bar", "baz"};
+  std::unique_ptr<InputBatchCache> input_batch_cache(new InputBatchCache(data));
+  input_batch_cache->GetAs<IdentityBatch>();
+
+  // Inject the cache into the session.
+  session->SetInputBatchCache(std::move(input_batch_cache));
+
+  // Check that the injected batch can be retrieved.
+  EXPECT_EQ(session->GetSerializedPredictions(), data);
+}
+
 }  // namespace dragnn
 }  // namespace syntaxnet
research/syntaxnet/dragnn/core/compute_session_pool.h  View file @ 4364390a
...
@@ -13,8 +13,8 @@
 // limitations under the License.
 // =============================================================================

-#ifndef NLP_SAFT_OPENSOURCE_DRAGNN_CORE_COMPUTE_SESSION_POOL_H_
-#define NLP_SAFT_OPENSOURCE_DRAGNN_CORE_COMPUTE_SESSION_POOL_H_
+#ifndef DRAGNN_CORE_COMPUTE_SESSION_POOL_H_
+#define DRAGNN_CORE_COMPUTE_SESSION_POOL_H_

 #include <memory>
...
@@ -29,14 +29,14 @@ namespace dragnn {
 class ComputeSessionPool {
  public:
-  // Create a ComputeSessionPool that creates ComputeSessions for the given
+  // Creates a ComputeSessionPool that creates ComputeSessions for the given
   // MasterSpec and hyperparameters.
   ComputeSessionPool(const MasterSpec &master_spec,
                      const GridPoint &hyperparams);

   virtual ~ComputeSessionPool();

-  // Get a ComputeSession. This function will attempt to use an already-created
+  // Gets a ComputeSession. This function will attempt to use an already-created
   // ComputeSession, but if none are available a new one will be created.
   std::unique_ptr<ComputeSession> GetSession();
...
return
num_unique_sessions_
-
sessions_
.
size
();
}
// Returns the number of unique sessions that have been created.
int
num_unique_sessions
()
{
return
num_unique_sessions_
;
}
// Returns a reference to the underlying spec for this pool.
const
MasterSpec
&
GetSpec
()
const
{
return
master_spec_
;
}
private:
friend
class
ComputeSessionImplTestPoolAccessor
;
friend
class
ComputeSessionPoolTestPoolAccessor
;
...
...
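For context, a hedged sketch of how these accessors are exercised (the empty spec and hyperparameters mirror the SetInputBatchCache test earlier in this commit):

// Sketch only: pool bookkeeping with the new accessors.
MasterSpec spec;        // placeholder spec
GridPoint hyperparams;  // placeholder hyperparameters
ComputeSessionPool pool(spec, hyperparams);

std::unique_ptr<ComputeSession> session = pool.GetSession();
CHECK_EQ(pool.num_unique_sessions(), 1);       // one session has been created
const MasterSpec &pool_spec = pool.GetSpec();  // spec the pool was built from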
@@ -99,4 +105,4 @@ class ComputeSessionPool {
 }  // namespace dragnn
 }  // namespace syntaxnet

-#endif  // NLP_SAFT_OPENSOURCE_DRAGNN_CORE_COMPUTE_SESSION_POOL_H_
+#endif  // DRAGNN_CORE_COMPUTE_SESSION_POOL_H_
research/syntaxnet/dragnn/core/compute_session_pool_test.cc  View file @ 4364390a
...
@@ -207,6 +207,7 @@ TEST(ComputeSessionPoolTest, SupportsMultithreadedAccess) {
   std::vector<std::unique_ptr<tensorflow::Thread>> request_threads;
   constexpr int kNumThreadsToTest = 100;
+  request_threads.reserve(kNumThreadsToTest);
   for (int i = 0; i < kNumThreadsToTest; ++i) {
     request_threads.push_back(std::unique_ptr<tensorflow::Thread>(
         tensorflow::Env::Default()->StartThread(
...
research/syntaxnet/dragnn/core/index_translator.h  View file @ 4364390a
...
@@ -13,8 +13,8 @@
 // limitations under the License.
 // =============================================================================

-#ifndef NLP_SAFT_OPENSOURCE_DRAGNN_CORE_INDEX_TRANSLATOR_H_
-#define NLP_SAFT_OPENSOURCE_DRAGNN_CORE_INDEX_TRANSLATOR_H_
+#ifndef DRAGNN_CORE_INDEX_TRANSLATOR_H_
+#define DRAGNN_CORE_INDEX_TRANSLATOR_H_

 #include <memory>
 #include <vector>
...
@@ -80,4 +80,4 @@ class IndexTranslator {
 }  // namespace dragnn
 }  // namespace syntaxnet

-#endif  // NLP_SAFT_OPENSOURCE_DRAGNN_CORE_INDEX_TRANSLATOR_H_
+#endif  // DRAGNN_CORE_INDEX_TRANSLATOR_H_
research/syntaxnet/dragnn/core/input_batch_cache.h  View file @ 4364390a
...
@@ -13,12 +13,15 @@
 // limitations under the License.
 // =============================================================================

-#ifndef NLP_SAFT_OPENSOURCE_DRAGNN_CORE_INPUT_BATCH_CACHE_H_
-#define NLP_SAFT_OPENSOURCE_DRAGNN_CORE_INPUT_BATCH_CACHE_H_
+#ifndef DRAGNN_CORE_INPUT_BATCH_CACHE_H_
+#define DRAGNN_CORE_INPUT_BATCH_CACHE_H_

 #include <memory>
 #include <string>
 #include <type_traits>
 #include <typeindex>
 #include <typeinfo>
 #include <utility>

 #include "dragnn/core/interfaces/input_batch.h"
 #include "tensorflow/core/platform/logging.h"
...
@@ -42,6 +45,18 @@ class InputBatchCache {
   explicit InputBatchCache(const std::vector<string> &data)
       : stored_type_(std::type_index(typeid(void))), source_data_(data) {}

+  // Creates a InputBatchCache from the |batch|. InputBatchSubclass must be a
+  // strict subclass of InputBatch, and |batch| must be non-null. All calls to
+  // GetAs must match InputBatchSubclass.
+  template <class InputBatchSubclass>
+  explicit InputBatchCache(std::unique_ptr<InputBatchSubclass> batch)
+      : stored_type_(std::type_index(typeid(InputBatchSubclass))),
+        converted_data_(std::move(batch)) {
+    static_assert(IsStrictInputBatchSubclass<InputBatchSubclass>(),
+                  "InputBatchCache requires a strict subclass of InputBatch");
+    CHECK(converted_data_) << "Cannot initialize from a null InputBatch";
+  }
+
   // Adds a single string to the cache. Only useable before GetAs() has been
   // called.
   void AddData(const string &data) {
...
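A hedged sketch of what the new constructor accepts and rejects, using the StringData test batch defined later in this commit:

// Sketch only: constructing an InputBatchCache from an existing batch.
std::unique_ptr<StringData> typed_batch(new StringData());
InputBatchCache cache(std::move(typed_batch));  // OK: strict InputBatch subclass

// Passing InputBatch itself would not compile: the static_assert rejects the
// base class because it is not a *strict* subclass.

// A null pointer of a valid subclass type compiles but CHECK-fails at runtime,
// as exercised by CannotInitializeFromNullInputBatch below.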
@@ -52,10 +67,14 @@ class InputBatchCache {
   }

   // Converts the stored strings into protos and return them in a specific
-  // InputBatch subclass. T should always be of type InputBatch. After this
-  // method is called once, all further calls must be of the same data type.
+  // InputBatch subclass. T should always be a strict subclass of InputBatch.
+  // After this method is called once, all further calls must be of the same
+  // data type.
   template <class T>
   T *GetAs() {
+    static_assert(IsStrictInputBatchSubclass<T>(),
+                  "GetAs<T>() requires that T is a strict subclass of InputBatch");
     if (!converted_data_) {
       stored_type_ = std::type_index(typeid(T));
       converted_data_.reset(new T());
...
@@ -69,14 +88,27 @@ class InputBatchCache {
     return dynamic_cast<T *>(converted_data_.get());
   }

+  // Returns the size of the batch. Requires that GetAs() has been called.
+  int Size() const {
+    CHECK(converted_data_) << "Cannot return batch size without data.";
+    return converted_data_->GetSize();
+  }
+
   // Returns the serialized representation of the data held in the input batch
-  // object within this cache.
+  // object within this cache. Requires that GetAs() has been called.
   const std::vector<string> SerializedData() const {
     CHECK(converted_data_) << "Cannot return batch without data.";
     return converted_data_->GetSerializedData();
   }

  private:
+  // Returns true if InputBatchSubclass is a strict subclass of InputBatch.
+  template <class InputBatchSubclass>
+  static constexpr bool IsStrictInputBatchSubclass() {
+    return std::is_base_of<InputBatch, InputBatchSubclass>::value &&
+           !std::is_same<InputBatch, InputBatchSubclass>::value;
+  }
+
   // The typeid of the stored data.
   std::type_index stored_type_;
...
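The strictness check is just two standard type traits composed; a standalone sketch of the same idea, written generically (the names here are illustrative):

// Sketch only: the generic form of IsStrictInputBatchSubclass().
template <class Base, class Derived>
constexpr bool IsStrictSubclass() {
  return std::is_base_of<Base, Derived>::value &&
         !std::is_same<Base, Derived>::value;
}
// IsStrictSubclass<InputBatch, StringData>()  -> true
// IsStrictSubclass<InputBatch, InputBatch>()  -> false (the base itself is rejected)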
@@ -90,4 +122,4 @@ class InputBatchCache {
 }  // namespace dragnn
 }  // namespace syntaxnet

-#endif  // NLP_SAFT_OPENSOURCE_DRAGNN_CORE_INPUT_BATCH_CACHE_H_
+#endif  // DRAGNN_CORE_INPUT_BATCH_CACHE_H_
research/syntaxnet/dragnn/core/input_batch_cache_test.cc  View file @ 4364390a
...
@@ -32,6 +32,8 @@ class StringData : public InputBatch {
     }
   }

+  int GetSize() const override { return data_.size(); }
+
   const std::vector<string> GetSerializedData() const override { return data_; }

   std::vector<string> *data() { return &data_; }
...
}
}
int
GetSize
()
const
override
{
return
data_
.
size
();
}
const
std
::
vector
<
string
>
GetSerializedData
()
const
override
{
return
data_
;
}
std
::
vector
<
string
>
*
data
()
{
return
&
data_
;
}
...
...
@@ -58,6 +62,11 @@ class DifferentStringData : public InputBatch {
   std::vector<string> data_;
 };

+// Expects that two pointers have the same address.
+void ExpectSameAddress(const void *pointer1, const void *pointer2) {
+  EXPECT_EQ(pointer1, pointer2);
+}
+
 TEST(InputBatchCacheTest, ConvertsSingleInput) {
   string test_string = "Foo";
   InputBatchCache generic_set(test_string);
...
@@ -118,5 +127,48 @@ TEST(InputBatchCacheTest, ConvertsAddedInputDiesAfterGetAs) {
                "after the cache has been converted");
 }

+TEST(InputBatchCacheTest, SerializedDataAndSize) {
+  InputBatchCache generic_set;
+  generic_set.AddData("Foo");
+  generic_set.AddData("Bar");
+  generic_set.GetAs<StringData>();
+
+  const std::vector<string> expected_data = {"Foo_converted", "Bar_converted"};
+  EXPECT_EQ(expected_data, generic_set.SerializedData());
+  EXPECT_EQ(2, generic_set.Size());
+}
+
+TEST(InputBatchCacheTest, InitializeFromInputBatch) {
+  const std::vector<string> kInputData = {"foo", "bar", "baz"};
+  const std::vector<string> kExpectedData = {"foo_converted",  //
+                                             "bar_converted",  //
+                                             "baz_converted"};
+  std::unique_ptr<StringData> string_data(new StringData());
+  string_data->SetData(kInputData);
+  const StringData *string_data_ptr = string_data.get();
+
+  InputBatchCache generic_set(std::move(string_data));
+  auto data = generic_set.GetAs<StringData>();
+  ExpectSameAddress(string_data_ptr, data);
+  EXPECT_EQ(data->GetSize(), 3);
+  EXPECT_EQ(data->GetSerializedData(), kExpectedData);
+  EXPECT_EQ(*data->data(), kExpectedData);
+
+  // AddData() shouldn't work since the cache is already populated.
+  EXPECT_DEATH(generic_set.AddData("YOU MAY NOT DO THIS AND IT WILL DIE."),
+               "after the cache has been converted");
+
+  // GetAs() shouldn't work with a different type.
+  EXPECT_DEATH(generic_set.GetAs<DifferentStringData>(),
+               "Attempted to convert to two object types!");
+}
+
+TEST(InputBatchCacheTest, CannotInitializeFromNullInputBatch) {
+  EXPECT_DEATH(InputBatchCache(std::unique_ptr<StringData>()),
+               "Cannot initialize from a null InputBatch");
+}
+
 }  // namespace dragnn
 }  // namespace syntaxnet
research/syntaxnet/dragnn/core/interfaces/cloneable_transition_state.h  View file @ 4364390a
...
@@ -13,8 +13,8 @@
 // limitations under the License.
 // =============================================================================

-#ifndef NLP_SAFT_OPENSOURCE_DRAGNN_CORE_INTERFACES_CLONEABLE_TRANSITION_STATE_H_
-#define NLP_SAFT_OPENSOURCE_DRAGNN_CORE_INTERFACES_CLONEABLE_TRANSITION_STATE_H_
+#ifndef DRAGNN_CORE_INTERFACES_CLONEABLE_TRANSITION_STATE_H_
+#define DRAGNN_CORE_INTERFACES_CLONEABLE_TRANSITION_STATE_H_

 #include <memory>
 #include <vector>
...
@@ -33,26 +33,32 @@ class CloneableTransitionState : public TransitionState {
  public:
  ~CloneableTransitionState<T>() override {}

-  // Initialize this TransitionState from a previous TransitionState. The
+  // Initializes this TransitionState from a previous TransitionState. The
   // ParentBeamIndex is the location of that previous TransitionState in the
   // provided beam.
   void Init(const TransitionState &parent) override = 0;

-  // Return the beam index of the state passed into the initializer of this
+  // Returns the beam index of the state passed into the initializer of this
   // TransitionState.
-  const int ParentBeamIndex() const override = 0;
+  int ParentBeamIndex() const override = 0;

-  // Get the current beam index for this state.
-  const int GetBeamIndex() const override = 0;
+  // Gets the current beam index for this state.
+  int GetBeamIndex() const override = 0;

-  // Set the current beam index for this state.
-  void SetBeamIndex(const int index) override = 0;
+  // Sets the current beam index for this state.
+  void SetBeamIndex(int index) override = 0;

-  // Get the score associated with this transition state.
-  const float GetScore() const override = 0;
+  // Gets the score associated with this transition state.
+  float GetScore() const override = 0;

-  // Set the score associated with this transition state.
-  void SetScore(const float score) override = 0;
+  // Sets the score associated with this transition state.
+  void SetScore(float score) override = 0;

+  // Gets the gold-ness of this state (whether it is on the oracle path)
+  bool IsGold() const override = 0;
+
+  // Sets the gold-ness of this state.
+  void SetGold(bool is_gold) override = 0;
+
   // Depicts this state as an HTML-language string.
   string HTMLRepresentation() const override = 0;
...
@@ -64,4 +70,4 @@ class CloneableTransitionState : public TransitionState {
 }  // namespace dragnn
 }  // namespace syntaxnet

-#endif  // NLP_SAFT_OPENSOURCE_DRAGNN_CORE_INTERFACES_CLONEABLE_TRANSITION_STATE_H_
+#endif  // DRAGNN_CORE_INTERFACES_CLONEABLE_TRANSITION_STATE_H_
research/syntaxnet/dragnn/core/interfaces/component.h  View file @ 4364390a
...
@@ -13,8 +13,8 @@
 // limitations under the License.
 // =============================================================================

-#ifndef NLP_SAFT_OPENSOURCE_DRAGNN_CORE_INTERFACES_COMPONENT_H_
-#define NLP_SAFT_OPENSOURCE_DRAGNN_CORE_INTERFACES_COMPONENT_H_
+#ifndef DRAGNN_CORE_INTERFACES_COMPONENT_H_
+#define DRAGNN_CORE_INTERFACES_COMPONENT_H_

 #include <vector>
...
@@ -83,11 +83,13 @@ class Component : public RegisterableClass<Component> {
   virtual std::function<int(int, int, int)> GetStepLookupFunction(
       const string &method) = 0;

-  // Advances this component from the given transition matrix.
-  virtual void AdvanceFromPrediction(const float transition_matrix[],
-                                     int transition_matrix_length) = 0;
+  // Advances this component from the given transition matrix, which is
+  // |num_items| x |num_actions|.
+  virtual bool AdvanceFromPrediction(const float *score_matrix, int num_items,
+                                     int num_actions) = 0;

-  // Advances this component from the state oracles.
+  // Advances this component from the state oracles. There is no return from
+  // this, since it should always succeed.
   virtual void AdvanceFromOracle() = 0;

   // Returns true if all states within this component are terminal.
...
@@ -110,6 +112,14 @@ class Component : public RegisterableClass<Component> {
   // BulkFeatureExtractor object to contain the functors and other information.
   virtual int BulkGetFixedFeatures(const BulkFeatureExtractor &extractor) = 0;

+  // Directly computes the embedding matrix for all channels, advancing the
+  // component via the oracle until it is terminal. This call takes a vector
+  // of EmbeddingMatrix structs, one per channel, in channel order.
+  virtual void BulkEmbedFixedFeatures(
+      int batch_size_padding, int num_steps_padding, int output_array_size,
+      const vector<const float *> &per_channel_embeddings,
+      float *embedding_output) = 0;
+
   // Extracts and returns the vector of LinkFeatures for the specified
   // channel. Note: these are NOT translated.
   virtual std::vector<LinkFeatures> GetRawLinkFeatures(
...
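To make the new contract concrete, here is a hedged sketch of a Component override that validates the score layout and reports failure instead of CHECK-failing; the state helpers it calls are placeholders, not part of this interface:

// Sketch only; a real component also updates per-state scores and histories.
bool MyComponent::AdvanceFromPrediction(const float *score_matrix,
                                        int num_items, int num_actions) {
  // The matrix is row-major, |num_items| x |num_actions|.
  if (num_items != BatchSize() * BeamSize() || num_actions <= 0) {
    LOG(ERROR) << "Score matrix shape mismatch: " << num_items << " x "
               << num_actions;
    return false;  // Callers (e.g. the bulk ops) surface this as an error.
  }
  for (int item = 0; item < num_items; ++item) {
    const float *scores = score_matrix + item * num_actions;
    AdvanceItemWithScores(item, scores, num_actions);  // hypothetical helper
  }
  return true;
}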
@@ -138,4 +148,4 @@ class Component : public RegisterableClass<Component> {
 }  // namespace dragnn
 }  // namespace syntaxnet

-#endif  // NLP_SAFT_OPENSOURCE_DRAGNN_CORE_INTERFACES_COMPONENT_H_
+#endif  // DRAGNN_CORE_INTERFACES_COMPONENT_H_
research/syntaxnet/dragnn/core/interfaces/input_batch.h  View file @ 4364390a
...
@@ -13,8 +13,8 @@
 // limitations under the License.
 // =============================================================================

-#ifndef NLP_SAFT_OPENSOURCE_DRAGNN_CORE_INTERFACES_INPUT_BATCH_H_
-#define NLP_SAFT_OPENSOURCE_DRAGNN_CORE_INTERFACES_INPUT_BATCH_H_
+#ifndef DRAGNN_CORE_INTERFACES_INPUT_BATCH_H_
+#define DRAGNN_CORE_INTERFACES_INPUT_BATCH_H_

 #include <string>
 #include <vector>
...
@@ -32,14 +32,17 @@ class InputBatch {
  public:
  virtual ~InputBatch() {}

-  // Set the data to translate to the subclass' data type.
+  // Sets the data to translate to the subclass' data type. Call at most once.
   virtual void SetData(const std::vector<string> &data) = 0;

-  // Translate the underlying data back to a vector of strings, as appropriate.
+  // Returns the size of the batch.
+  virtual int GetSize() const = 0;
+
+  // Translates the underlying data back to a vector of strings, as appropriate.
   virtual const std::vector<string> GetSerializedData() const = 0;
 };

 }  // namespace dragnn
 }  // namespace syntaxnet

-#endif  // NLP_SAFT_OPENSOURCE_DRAGNN_CORE_INTERFACES_INPUT_BATCH_H_
+#endif  // DRAGNN_CORE_INTERFACES_INPUT_BATCH_H_
research/syntaxnet/dragnn/core/interfaces/transition_state.h  View file @ 4364390a
...
@@ -13,8 +13,8 @@
 // limitations under the License.
 // =============================================================================

-#ifndef NLP_SAFT_OPENSOURCE_DRAGNN_CORE_INTERFACES_TRANSITION_STATE_H_
-#define NLP_SAFT_OPENSOURCE_DRAGNN_CORE_INTERFACES_TRANSITION_STATE_H_
+#ifndef DRAGNN_CORE_INTERFACES_TRANSITION_STATE_H_
+#define DRAGNN_CORE_INTERFACES_TRANSITION_STATE_H_

 #include <memory>
 #include <vector>
...
@@ -44,19 +44,25 @@ class TransitionState {
   // Return the beam index of the state passed into the initializer of this
   // TransitionState.
-  virtual const int ParentBeamIndex() const = 0;
+  virtual int ParentBeamIndex() const = 0;

-  // Get the current beam index for this state.
-  virtual const int GetBeamIndex() const = 0;
+  // Gets the current beam index for this state.
+  virtual int GetBeamIndex() const = 0;

-  // Set the current beam index for this state.
-  virtual void SetBeamIndex(const int index) = 0;
+  // Sets the current beam index for this state.
+  virtual void SetBeamIndex(int index) = 0;

-  // Get the score associated with this transition state.
-  virtual const float GetScore() const = 0;
+  // Gets the score associated with this transition state.
+  virtual float GetScore() const = 0;

-  // Set the score associated with this transition state.
-  virtual void SetScore(const float score) = 0;
+  // Sets the score associated with this transition state.
+  virtual void SetScore(float score) = 0;

+  // Gets the gold-ness of this state (whether it is on the oracle path)
+  virtual bool IsGold() const = 0;
+
+  // Sets the gold-ness of this state.
+  virtual void SetGold(bool is_gold) = 0;
+
   // Depicts this state as an HTML-language string.
   virtual string HTMLRepresentation() const = 0;
...
@@ -65,4 +71,4 @@ class TransitionState {
 }  // namespace dragnn
 }  // namespace syntaxnet

-#endif  // NLP_SAFT_OPENSOURCE_DRAGNN_CORE_INTERFACES_TRANSITION_STATE_H_
+#endif  // DRAGNN_CORE_INTERFACES_TRANSITION_STATE_H_
research/syntaxnet/dragnn/core/ops/compute_session_op.h  View file @ 4364390a
...
@@ -13,8 +13,8 @@
 // limitations under the License.
 // =============================================================================

-#ifndef NLP_SAFT_OPENSOURCE_DRAGNN_CORE_OPS_COMPUTE_SESSION_OP_H_
-#define NLP_SAFT_OPENSOURCE_DRAGNN_CORE_OPS_COMPUTE_SESSION_OP_H_
+#ifndef DRAGNN_CORE_OPS_COMPUTE_SESSION_OP_H_
+#define DRAGNN_CORE_OPS_COMPUTE_SESSION_OP_H_

 #include <string>
...
@@ -66,4 +66,4 @@ class ComputeSessionOp : public tensorflow::OpKernel {
 }  // namespace dragnn
 }  // namespace syntaxnet

-#endif  // NLP_SAFT_OPENSOURCE_DRAGNN_CORE_OPS_COMPUTE_SESSION_OP_H_
+#endif  // DRAGNN_CORE_OPS_COMPUTE_SESSION_OP_H_
research/syntaxnet/dragnn/core/ops/dragnn_bulk_op_kernels.cc  View file @ 4364390a
...
@@ -303,6 +303,73 @@ class BulkFixedEmbeddings : public ComputeSessionOp {
 REGISTER_KERNEL_BUILDER(Name("BulkFixedEmbeddings").Device(DEVICE_CPU),
                         BulkFixedEmbeddings);
+// See docstring in dragnn_bulk_ops.cc.
+class BulkEmbedFixedFeatures : public ComputeSessionOp {
+ public:
+  explicit BulkEmbedFixedFeatures(OpKernelConstruction *context)
+      : ComputeSessionOp(context) {
+    OP_REQUIRES_OK(context, context->GetAttr("num_channels", &num_channels_));
+    // The input vector's zeroth element is the state handle, and the remaining
+    // num_channels_ elements are tensors of float embeddings, one per channel.
+    vector<DataType> input_types(num_channels_ + 1, DT_FLOAT);
+    input_types[0] = DT_STRING;
+    const vector<DataType> output_types = {DT_STRING, DT_FLOAT, DT_INT32};
+    OP_REQUIRES_OK(context, context->MatchSignature(input_types, output_types));
+    OP_REQUIRES_OK(context, context->GetAttr("pad_to_batch", &pad_to_batch_));
+    OP_REQUIRES_OK(context, context->GetAttr("pad_to_steps", &pad_to_steps_));
+  }
+
+  bool OutputsHandle() const override { return true; }
+  bool RequiresComponentName() const override { return true; }
+
+  void ComputeWithState(OpKernelContext *context,
+                        ComputeSession *session) override {
+    const auto &spec = session->Spec(component_name());
+    int embedding_size = 0;
+    std::vector<const float *> embeddings(num_channels_);
+    for (int channel = 0; channel < num_channels_; ++channel) {
+      const int embeddings_index = channel + 1;
+      embedding_size += context->input(embeddings_index).shape().dim_size(1) *
+                        spec.fixed_feature(channel).size();
+      embeddings[channel] =
+          context->input(embeddings_index).flat<float>().data();
+    }
+    Tensor *embedding_vectors;
+    OP_REQUIRES_OK(context,
+                   context->allocate_output(
+                       1,
+                       TensorShape({pad_to_steps_ * pad_to_batch_ *
+                                        session->BeamSize(component_name()),
+                                    embedding_size}),
+                       &embedding_vectors));
+    Tensor *num_steps_tensor;
+    OP_REQUIRES_OK(context, context->allocate_output(2, TensorShape({}),
+                                                     &num_steps_tensor));
+    embedding_vectors->flat<float>().setZero();
+    int output_size = embedding_vectors->NumElements();
+    session->BulkEmbedFixedFeatures(component_name(), pad_to_batch_,
+                                    pad_to_steps_, output_size, embeddings,
+                                    embedding_vectors->flat<float>().data());
+    num_steps_tensor->scalar<int32>()() = pad_to_steps_;
+  }
+
+ private:
+  // Number of fixed feature channels.
+  int num_channels_;
+
+  // Will pad output to this many batch elements.
+  int pad_to_batch_;
+
+  // Will pad output to this many steps.
+  int pad_to_steps_;
+
+  TF_DISALLOW_COPY_AND_ASSIGN(BulkEmbedFixedFeatures);
+};
+
+REGISTER_KERNEL_BUILDER(Name("BulkEmbedFixedFeatures").Device(DEVICE_CPU),
+                        BulkEmbedFixedFeatures);
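The output allocated above is [pad_to_steps * pad_to_batch * beam_size, embedding_size], where embedding_size sums embedding-matrix width times fixed-feature count over the channels. A small worked example using the same values as the kernel test later in this commit:

// Worked example of the allocation arithmetic (values match the test below).
constexpr int kPadToBatch = 7;
constexpr int kPadToSteps = 5;
constexpr int kBeamSize = 3;
// Channel 0: width 3, feature count 2; channel 1: width 9, feature count 1.
constexpr int kEmbeddingSize = 3 * 2 + 9 * 1;                    // = 15
constexpr int kNumRows = kPadToSteps * kPadToBatch * kBeamSize;  // = 105
constexpr int kOutputSize = kNumRows * kEmbeddingSize;           // = 1575
static_assert(kOutputSize == 1575, "matches embedding_vectors->NumElements()");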
 // See docstring in dragnn_bulk_ops.cc.
 class BulkAdvanceFromOracle : public ComputeSessionOp {
  public:
...
@@ -387,8 +454,11 @@ class BulkAdvanceFromPrediction : public ComputeSessionOp {
       }
     }
     if (!session->IsTerminal(component_name())) {
-      session->AdvanceFromPrediction(component_name(), scores_per_step.data(),
-                                     scores_per_step.size());
+      bool success = session->AdvanceFromPrediction(
+          component_name(), scores_per_step.data(), num_items, num_actions);
+      OP_REQUIRES(context, success,
+                  tensorflow::errors::Internal(
+                      "Unable to advance from prediction."));
     }
   }
 }
...
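With this change a failed advance surfaces as a TensorFlow error status rather than silently continuing; the new BulkAdvanceFromPredictionFailsIfAdvanceFails test below asserts exactly that. A hedged sketch of what a caller of the op kernel observes:

// Sketch only: the failure now propagates as a non-OK Status to the caller.
tensorflow::Status status = RunOpKernelWithContext();  // test-fixture helper
if (!status.ok()) {
  LOG(ERROR) << "BulkAdvanceFromPrediction failed: " << status.error_message();
}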
research/syntaxnet/dragnn/core/ops/dragnn_bulk_op_kernels_test.cc  View file @ 4364390a
...
@@ -375,6 +375,114 @@ TEST_F(DragnnBulkOpKernelsTest, BulkFixedEmbeddings) {
   EXPECT_EQ(kNumSteps, GetOutput(2)->scalar<int32>()());
 }
+TEST_F(DragnnBulkOpKernelsTest, BulkEmbedFixedFeatures) {
+  // Create and initialize the kernel under test.
+  constexpr int kBatchPad = 7;
+  constexpr int kStepPad = 5;
+  constexpr int kMaxBeamSize = 3;
+  TF_ASSERT_OK(
+      NodeDefBuilder("BulkEmbedFixedFeatures", "BulkEmbedFixedFeatures")
+          .Attr("component", kComponentName)
+          .Attr("num_channels", kNumChannels)
+          .Attr("pad_to_batch", kBatchPad)
+          .Attr("pad_to_steps", kStepPad)
+          .Input(FakeInput(DT_STRING))  // The handle for the ComputeSession.
+          .Input(FakeInput(DT_FLOAT))   // Embedding matrices.
+          .Finalize(node_def()));
+  MockComputeSession *mock_session = GetMockSession();
+
+  ComponentSpec spec;
+  spec.set_name(kComponentName);
+  auto chan0_spec = spec.add_fixed_feature();
+  constexpr int kChan0FeatureCount = 2;
+  chan0_spec->set_size(kChan0FeatureCount);
+  auto chan1_spec = spec.add_fixed_feature();
+  constexpr int kChan1FeatureCount = 1;
+  chan1_spec->set_size(kChan1FeatureCount);
+  EXPECT_CALL(*mock_session, Spec(kComponentName))
+      .WillOnce(testing::ReturnRef(spec));
+  EXPECT_CALL(*mock_session, BeamSize(kComponentName))
+      .WillOnce(testing::Return(kMaxBeamSize));
+
+  // Embedding matrices as additional inputs.
+  // For channel 0, the embeddings are [id, id, id].
+  // For channel 1, the embeddings are [id^2, id^2, id^2, ... ,id^2].
+  vector<float> embedding_matrix_0;
+  constexpr int kEmbedding0Size = 3;
+  vector<float> embedding_matrix_1;
+  constexpr int kEmbedding1Size = 9;
+  for (int id = 0; id < kNumIds; ++id) {
+    for (int i = 0; i < kEmbedding0Size; ++i) {
+      embedding_matrix_0.push_back(id);
+      LOG(INFO) << embedding_matrix_0.back();
+    }
+    for (int i = 0; i < kEmbedding1Size; ++i) {
+      embedding_matrix_1.push_back(id * id);
+      LOG(INFO) << embedding_matrix_0.back();
+    }
+  }
+  AddInputFromArray<float>(TensorShape({kNumIds, kEmbedding0Size}),
+                           embedding_matrix_0);
+  AddInputFromArray<float>(TensorShape({kNumIds, kEmbedding1Size}),
+                           embedding_matrix_1);
+
+  constexpr int kExpectedEmbeddingSize =
+      kChan0FeatureCount * kEmbedding0Size + kChan1FeatureCount * kEmbedding1Size;
+  constexpr int kExpectedOutputSize =
+      kExpectedEmbeddingSize * kBatchPad * kStepPad * kMaxBeamSize;
+
+  // This function takes the allocator functions passed into GetBulkFF, uses
+  // them to allocate a tensor, then fills that tensor based on channel.
+  auto eval_function = [=](
+      const string &component_name, int batch_size_padding,
+      int num_steps_padding, int output_array_size,
+      const vector<const float *> &per_channel_embeddings,
+      float *embedding_output) {
+    // Validate the control variables.
+    EXPECT_EQ(batch_size_padding, kBatchPad);
+    EXPECT_EQ(num_steps_padding, kStepPad);
+    EXPECT_EQ(output_array_size, kExpectedOutputSize);
+
+    // Validate the passed embeddings.
+    for (int i = 0; i < kNumIds; ++i) {
+      for (int j = 0; j < kEmbedding0Size; ++j) {
+        float ch0_embedding =
+            per_channel_embeddings.at(0)[i * kEmbedding0Size + j];
+        EXPECT_FLOAT_EQ(ch0_embedding, i) << "Failed match at " << i << "," << j;
+      }
+      for (int j = 0; j < kEmbedding1Size; ++j) {
+        float ch1_embedding =
+            per_channel_embeddings.at(1)[i * kEmbedding1Size + j];
+        EXPECT_FLOAT_EQ(ch1_embedding, i * i)
+            << "Failed match at " << i << "," << j;
+      }
+    }
+
+    // Fill the output matrix to the expected size. This will trigger msan
+    // if the allocation wasn't big enough.
+    for (int i = 0; i < kExpectedOutputSize; ++i) {
+      embedding_output[i] = i;
+    }
+  };
+  EXPECT_CALL(*mock_session,
+              BulkEmbedFixedFeatures(kComponentName, _, _, _, _, _))
+      .WillOnce(testing::Invoke(eval_function));
+
+  // Run the kernel.
+  TF_EXPECT_OK(RunOpKernelWithContext());
+
+  // Validate outputs.
+  EXPECT_EQ(kBatchPad * kStepPad * kMaxBeamSize,
+            GetOutput(1)->shape().dim_size(0));
+  EXPECT_EQ(kExpectedEmbeddingSize, GetOutput(1)->shape().dim_size(1));
+  auto output_data = GetOutput(1)->flat<float>();
+  for (int i = 0; i < kExpectedOutputSize; ++i) {
+    EXPECT_FLOAT_EQ(i, output_data(i));
+  }
+  EXPECT_EQ(kStepPad, GetOutput(2)->scalar<int32>()());
+}
+
 TEST_F(DragnnBulkOpKernelsTest, BulkFixedEmbeddingsWithPadding) {
   // Create and initialize the kernel under test.
   constexpr int kPaddedNumSteps = 5;
...
@@ -592,12 +700,54 @@ TEST_F(DragnnBulkOpKernelsTest, BulkAdvanceFromPrediction) {
   EXPECT_CALL(*mock_session,
               AdvanceFromPrediction(kComponentName,
                                     CheckScoresAreConsecutiveIntegersDivTen(),
-                                    kNumItems * kNumActions))
-      .Times(kNumSteps);
+                                    kNumItems, kNumActions))
+      .Times(kNumSteps)
+      .WillRepeatedly(Return(true));

   // Run the kernel.
   TF_EXPECT_OK(RunOpKernelWithContext());
 }

+TEST_F(DragnnBulkOpKernelsTest, BulkAdvanceFromPredictionFailsIfAdvanceFails) {
+  // Create and initialize the kernel under test.
+  TF_ASSERT_OK(
+      NodeDefBuilder("BulkAdvanceFromPrediction", "BulkAdvanceFromPrediction")
+          .Attr("component", kComponentName)
+          .Input(FakeInput(DT_STRING))  // The handle for the ComputeSession.
+          .Input(FakeInput(DT_FLOAT))   // Prediction scores for advancing.
+          .Finalize(node_def()));
+  MockComputeSession *mock_session = GetMockSession();
+
+  // Creates an input tensor such that each step will see a list of consecutive
+  // integers divided by 10 as scores.
+  vector<float> scores(kNumItems * kNumSteps * kNumActions);
+  for (int step(0), cnt(0); step < kNumSteps; ++step) {
+    for (int item = 0; item < kNumItems; ++item) {
+      for (int action = 0; action < kNumActions; ++action, ++cnt) {
+        scores[action + kNumActions * (step + item * kNumSteps)] = cnt / 10.0f;
+      }
+    }
+  }
+  AddInputFromArray<float>(TensorShape({kNumItems * kNumSteps, kNumActions}),
+                           scores);
+
+  EXPECT_CALL(*mock_session, BeamSize(kComponentName)).WillOnce(Return(1));
+  EXPECT_CALL(*mock_session, BatchSize(kComponentName))
+      .WillOnce(Return(kNumItems));
+  EXPECT_CALL(*mock_session, IsTerminal(kComponentName))
+      .Times(2)
+      .WillRepeatedly(Return(false));
+  EXPECT_CALL(*mock_session,
+              AdvanceFromPrediction(kComponentName,
+                                    CheckScoresAreConsecutiveIntegersDivTen(),
+                                    kNumItems, kNumActions))
+      .WillOnce(Return(true))
+      .WillOnce(Return(false));
+
+  // Run the kernel.
+  auto result = RunOpKernelWithContext();
+  EXPECT_FALSE(result.ok());
+}
+
 }  // namespace dragnn
 }  // namespace syntaxnet