Commit 4364390a authored by Ivan Bogatyy's avatar Ivan Bogatyy Committed by calberti
Browse files

Release DRAGNN bulk networks (#2785)

* Release DRAGNN bulk networks
parent 638fd759
...@@ -33,8 +33,9 @@ cc_library( ...@@ -33,8 +33,9 @@ cc_library(
name = "compute_session", name = "compute_session",
hdrs = ["compute_session.h"], hdrs = ["compute_session.h"],
deps = [ deps = [
":index_translator",
":input_batch_cache",
"//dragnn/components/util:bulk_feature_extractor", "//dragnn/components/util:bulk_feature_extractor",
"//dragnn/core:index_translator",
"//dragnn/core/interfaces:component", "//dragnn/core/interfaces:component",
"//dragnn/protos:spec_proto", "//dragnn/protos:spec_proto",
"//dragnn/protos:trace_proto", "//dragnn/protos:trace_proto",
...@@ -120,8 +121,10 @@ cc_test( ...@@ -120,8 +121,10 @@ cc_test(
":compute_session", ":compute_session",
":compute_session_impl", ":compute_session_impl",
":compute_session_pool", ":compute_session_pool",
":input_batch_cache",
"//dragnn/components/util:bulk_feature_extractor", "//dragnn/components/util:bulk_feature_extractor",
"//dragnn/core/interfaces:component", "//dragnn/core/interfaces:component",
"//dragnn/core/interfaces:input_batch",
"//dragnn/core/test:generic", "//dragnn/core/test:generic",
"//dragnn/core/test:mock_component", "//dragnn/core/test:mock_component",
"//dragnn/core/test:mock_transition_state", "//dragnn/core/test:mock_transition_state",
...@@ -248,6 +251,7 @@ cc_library( ...@@ -248,6 +251,7 @@ cc_library(
"//syntaxnet:base", "//syntaxnet:base",
"@org_tensorflow//third_party/eigen3", "@org_tensorflow//third_party/eigen3",
], ],
alwayslink = 1,
) )
# Tensorflow kernel libraries, for use with unit tests. # Tensorflow kernel libraries, for use with unit tests.
......
...@@ -13,8 +13,8 @@ ...@@ -13,8 +13,8 @@
// limitations under the License. // limitations under the License.
// ============================================================================= // =============================================================================
#ifndef NLP_SAFT_OPENSOURCE_DRAGNN_CORE_BEAM_H_ #ifndef DRAGNN_CORE_BEAM_H_
#define NLP_SAFT_OPENSOURCE_DRAGNN_CORE_BEAM_H_ #define DRAGNN_CORE_BEAM_H_
#include <algorithm> #include <algorithm>
#include <cmath> #include <cmath>
...@@ -43,19 +43,23 @@ class Beam { ...@@ -43,19 +43,23 @@ class Beam {
static_assert( static_assert(
std::is_base_of<CloneableTransitionState<T>, T>::value, std::is_base_of<CloneableTransitionState<T>, T>::value,
"This class must be instantiated to use a CloneableTransitionState"); "This class must be instantiated to use a CloneableTransitionState");
track_gold_ = false;
} }
// Sets whether or not the beam should track gold states.
void SetGoldTracking(bool track_gold) { track_gold_ = track_gold; }
// Sets the Beam functions, as follows: // Sets the Beam functions, as follows:
// bool is_allowed(TransitionState *, int): Return true if transition 'int' is // bool is_allowed(TransitionState *, int): Return true if transition 'int' is
// allowed for transition state 'TransitionState *'. // allowed for transition state 'TransitionState *'.
// void perform_transition(TransitionState *, int): Performs transition 'int' // void perform_transition(TransitionState *, int): Performs transition 'int'
// on transition state 'TransitionState *'. // on transition state 'TransitionState *'.
// int oracle_function(TransitionState *): Returns the oracle-specified action // vector<int> oracle_function(TransitionState *): Returns the oracle-
// for transition state 'TransitionState *'. // specified actions for transition state 'TransitionState *'.
void SetFunctions(std::function<bool(T *, int)> is_allowed, void SetFunctions(std::function<bool(T *, int)> is_allowed,
std::function<bool(T *)> is_final, std::function<bool(T *)> is_final,
std::function<void(T *, int)> perform_transition, std::function<void(T *, int)> perform_transition,
std::function<int(T *)> oracle_function) { std::function<vector<int>(T *)> oracle_function) {
is_allowed_ = is_allowed; is_allowed_ = is_allowed;
is_final_ = is_final; is_final_ = is_final;
perform_transition_ = perform_transition; perform_transition_ = perform_transition;
...@@ -74,12 +78,17 @@ class Beam { ...@@ -74,12 +78,17 @@ class Beam {
for (int i = 0; i < beam_.size(); ++i) { for (int i = 0; i < beam_.size(); ++i) {
previous_beam_indices.at(i) = beam_[i]->ParentBeamIndex(); previous_beam_indices.at(i) = beam_[i]->ParentBeamIndex();
beam_[i]->SetBeamIndex(i); beam_[i]->SetBeamIndex(i);
// TODO(googleuser): Add gold tracking to component-level state creation.
if (!track_gold_) {
beam_[i]->SetGold(false);
}
} }
beam_index_history_.emplace_back(previous_beam_indices); beam_index_history_.emplace_back(previous_beam_indices);
} }
// Advances the Beam from the given transition matrix. // Advances the Beam from the given transition matrix.
void AdvanceFromPrediction(const float transition_matrix[], int matrix_length, bool AdvanceFromPrediction(const float *transition_matrix, int matrix_length,
int num_actions) { int num_actions) {
// Ensure that the transition matrix is the correct size. All underlying // Ensure that the transition matrix is the correct size. All underlying
// states should have the same transition profile, so using the one at 0 // states should have the same transition profile, so using the one at 0
...@@ -89,91 +98,20 @@ class Beam { ...@@ -89,91 +98,20 @@ class Beam {
"state transitions!"; "state transitions!";
if (max_size_ == 1) { if (max_size_ == 1) {
// In the case where beam size is 1, we can advance by simply finding the bool success = FastAdvanceFromPrediction(transition_matrix, num_actions);
// highest score and advancing the beam state in place. if (!success) {
VLOG(2) << "Beam size is 1. Using fast beam path."; return false;
int best_action = -1;
float best_score = -INFINITY;
auto &state = beam_[0];
for (int action_idx = 0; action_idx < num_actions; ++action_idx) {
if (is_allowed_(state.get(), action_idx) &&
transition_matrix[action_idx] > best_score) {
best_score = transition_matrix[action_idx];
best_action = action_idx;
}
} }
CHECK_GE(best_action, 0) << "Num actions: " << num_actions
<< " score[0]: " << transition_matrix[0];
perform_transition_(state.get(), best_action);
const float new_score = state->GetScore() + best_score;
state->SetScore(new_score);
state->SetBeamIndex(0);
} else { } else {
// Create the vector of all possible transitions, along with their scores. bool success = BeamAdvanceFromPrediction(transition_matrix, matrix_length,
std::vector<Transition> candidates; num_actions);
if (!success) {
// Iterate through all beams, examining all actions for each beam. return false;
for (int beam_idx = 0; beam_idx < beam_.size(); ++beam_idx) {
const auto &state = beam_[beam_idx];
for (int action_idx = 0; action_idx < num_actions; ++action_idx) {
// If the action is allowed, calculate the proposed new score and add
// the candidate action to the vector of all actions at this state.
if (is_allowed_(state.get(), action_idx)) {
Transition candidate;
// The matrix is laid out by beam index, with a linear set of
// actions for that index - so beam N's actions start at [nr. of
// actions]*[N].
const int matrix_idx = action_idx + beam_idx * num_actions;
CHECK_LT(matrix_idx, matrix_length)
<< "Matrix index out of bounds!";
const double score_delta = transition_matrix[matrix_idx];
CHECK(!std::isnan(score_delta));
candidate.source_idx = beam_idx;
candidate.action = action_idx;
candidate.resulting_score = state->GetScore() + score_delta;
candidates.emplace_back(candidate);
}
}
}
// Sort the vector of all possible transitions and scores.
const auto comparator = [](const Transition &a, const Transition &b) {
return a.resulting_score > b.resulting_score;
};
std::stable_sort(candidates.begin(), candidates.end(), comparator);
// Apply the top transitions, up to a maximum of 'max_size_'.
std::vector<std::unique_ptr<T>> new_beam;
std::vector<int> previous_beam_indices(max_size_, -1);
const int beam_size =
std::min(max_size_, static_cast<int>(candidates.size()));
VLOG(2) << "Previous beam size = " << beam_.size();
VLOG(2) << "New beam size = " << beam_size;
VLOG(2) << "Maximum beam size = " << max_size_;
for (int i = 0; i < beam_size; ++i) {
// Get the source of the i'th transition.
const auto &transition = candidates[i];
VLOG(2) << "Taking transition with score: "
<< transition.resulting_score
<< " and action: " << transition.action;
VLOG(2) << "transition.source_idx = " << transition.source_idx;
const auto &source = beam_[transition.source_idx];
// Put the new transition on the new state beam.
auto new_state = source->Clone();
perform_transition_(new_state.get(), transition.action);
new_state->SetScore(transition.resulting_score);
new_state->SetBeamIndex(i);
previous_beam_indices.at(i) = transition.source_idx;
new_beam.emplace_back(std::move(new_state));
} }
beam_ = std::move(new_beam);
beam_index_history_.emplace_back(previous_beam_indices);
} }
++num_steps_; ++num_steps_;
return true;
} }
// Advances the Beam from the state oracles. // Advances the Beam from the state oracles.
...@@ -182,7 +120,10 @@ class Beam { ...@@ -182,7 +120,10 @@ class Beam {
for (int i = 0; i < beam_.size(); ++i) { for (int i = 0; i < beam_.size(); ++i) {
previous_beam_indices.at(i) = i; previous_beam_indices.at(i) = i;
if (is_final_(beam_[i].get())) continue; if (is_final_(beam_[i].get())) continue;
const auto oracle_label = oracle_function_(beam_[i].get());
// There will always be at least one oracular transition, and taking the
// first returned transition is never worse than any other option.
const int oracle_label = oracle_function_(beam_[i].get()).at(0);
VLOG(2) << "AdvanceFromOracle beam_index:" << i VLOG(2) << "AdvanceFromOracle beam_index:" << i
<< " oracle_label:" << oracle_label; << " oracle_label:" << oracle_label;
perform_transition_(beam_[i].get(), oracle_label); perform_transition_(beam_[i].get(), oracle_label);
...@@ -312,19 +253,180 @@ class Beam { ...@@ -312,19 +253,180 @@ class Beam {
// Returns the current size of the beam. // Returns the current size of the beam.
const int size() const { return beam_.size(); } const int size() const { return beam_.size(); }
// Returns true if at least one of the states in the beam is gold.
bool ContainsGold() {
if (!track_gold_) {
return false;
}
for (const auto &state : beam_) {
if (state->IsGold()) {
return true;
}
}
return false;
}
private: private:
// Associates an action taken on an index into current_state_ with a score. friend void BM_FastAdvance(int num_iters, int num_transitions);
friend void BM_BeamAdvance(int num_iters, int num_transitions,
int max_beam_size);
// Associates an action taken with its source index.
struct Transition { struct Transition {
// The index of the source item. // The index of the source item.
int source_idx; int source_idx;
// The index of the action being taken. // The index of the action being taken.
int action; int action;
// The score of the full derivation.
double resulting_score;
}; };
// In the case where beam size is 1, we can advance by simply finding the
// highest score and advancing the beam state in place.
bool FastAdvanceFromPrediction(const float *transition_matrix,
int num_actions) {
CHECK_EQ(1, max_size_)
<< "Using fast advance on invalid beam. This should never happen.";
VLOG(2) << "Beam size is 1. Using fast beam path.";
constexpr int kNoActionChosen = -1;
int best_action = kNoActionChosen;
float best_score = -INFINITY;
auto &state = beam_[0];
for (int action_idx = 0; action_idx < num_actions; ++action_idx) {
if (std::isnan(transition_matrix[action_idx])) {
LOG(ERROR) << "Found a NaN in the transition matrix! Unable to "
"continue. Num actions: "
<< num_actions << " index: " << action_idx;
return false;
}
if (is_allowed_(state.get(), action_idx) &&
transition_matrix[action_idx] > best_score) {
best_score = transition_matrix[action_idx];
best_action = action_idx;
}
}
if (best_action == kNoActionChosen) {
LOG(ERROR) << "No action was chosen! Unable to continue. Num actions: "
<< num_actions << " score[0]: " << transition_matrix[0];
return false;
}
bool is_gold = false;
if (track_gold_ && state->IsGold()) {
for (const auto &gold_transition : oracle_function_(state.get())) {
VLOG(3) << "Examining gold transition " << gold_transition
<< " for source index 1";
if (gold_transition == best_action) {
is_gold = true;
break;
}
}
}
perform_transition_(state.get(), best_action);
const float new_score = state->GetScore() + best_score;
state->SetScore(new_score);
state->SetBeamIndex(0);
state->SetGold(is_gold);
return true;
}
// In case the beam size is greater than 1, we need to advance using
// standard beam search.
bool BeamAdvanceFromPrediction(const float *transition_matrix,
int matrix_length, int num_actions) {
VLOG(2) << "Beam size is " << max_size_ << ". Using standard beam search.";
// Keep the multimap sorted high to low. The sort order for
// identical keys is stable.
std::multimap<float, Transition, std::greater<float>> candidates;
float threshold = -INFINITY;
// Iterate through all beams, examining all actions for each beam.
for (int beam_idx = 0; beam_idx < beam_.size(); ++beam_idx) {
const auto &state = beam_[beam_idx];
const float score = state->GetScore();
for (int action_idx = 0; action_idx < num_actions; ++action_idx) {
if (is_allowed_(state.get(), action_idx)) {
// The matrix is laid out by beam index, with a linear set of
// actions for that index - so beam N's actions start at [nr. of
// actions]*[N].
const int matrix_idx = action_idx + beam_idx * num_actions;
CHECK_LT(matrix_idx, matrix_length) << "Matrix index out of bounds!";
const float resulting_score = score + transition_matrix[matrix_idx];
if (std::isnan(resulting_score)) {
LOG(ERROR) << "Resulting score was a NaN! Unable to continue. Num "
"actions: "
<< num_actions << " action index " << action_idx;
return false;
}
if (candidates.size() == max_size_) {
// If the new score is lower than the bottom of the beam, move on.
if (resulting_score < threshold) {
continue;
}
// Otherwise, remove the bottom of the beam, making space
// for the new candidate.
candidates.erase(std::prev(candidates.end()));
}
// Add the new candidate, and update the threshold score.
const Transition candidate{beam_idx, action_idx};
candidates.emplace(resulting_score, candidate);
threshold = candidates.rbegin()->first;
}
}
}
// Apply the top transitions, up to a maximum of 'max_size_'.
std::vector<std::unique_ptr<T>> new_beam;
std::vector<int> previous_beam_indices(max_size_, -1);
const int beam_size = candidates.size();
new_beam.reserve(max_size_);
VLOG(2) << "Previous beam size = " << beam_.size();
VLOG(2) << "New beam size = " << beam_size;
VLOG(2) << "Maximum beam size = " << max_size_;
auto candidate_iterator = candidates.cbegin();
for (int i = 0; i < beam_size; ++i) {
// Get the score and source of the i'th transition.
const float resulting_score = candidate_iterator->first;
const auto &transition = candidate_iterator->second;
++candidate_iterator;
VLOG(2) << "Taking transition with score: " << resulting_score
<< " and action: " << transition.action;
VLOG(2) << "transition.source_idx = " << transition.source_idx;
const auto &source = beam_[transition.source_idx];
// Determine if the transition being taken will result in a gold state.
bool is_gold = false;
if (track_gold_ && source->IsGold()) {
for (const auto &gold_transition : oracle_function_(source.get())) {
VLOG(3) << "Examining gold transition " << gold_transition
<< " for source index " << transition.source_idx;
if (gold_transition == transition.action) {
VLOG(2) << "State from index " << transition.source_idx
<< " is gold.";
is_gold = true;
break;
}
}
}
VLOG(2) << "Gold examination complete for source index "
<< transition.source_idx;
// Put the new transition on the new state beam.
auto new_state = source->Clone();
perform_transition_(new_state.get(), transition.action);
new_state->SetScore(resulting_score);
new_state->SetBeamIndex(i);
new_state->SetGold(is_gold);
previous_beam_indices.at(i) = transition.source_idx;
new_beam.emplace_back(std::move(new_state));
}
beam_ = std::move(new_beam);
beam_index_history_.emplace_back(previous_beam_indices);
return true;
}
// The maximum beam size. // The maximum beam size.
int max_size_; int max_size_;
...@@ -341,7 +443,7 @@ class Beam { ...@@ -341,7 +443,7 @@ class Beam {
std::function<void(T *, int)> perform_transition_; std::function<void(T *, int)> perform_transition_;
// Function to provide the oracle action for a given state. // Function to provide the oracle action for a given state.
std::function<int(T *)> oracle_function_; std::function<vector<int>(T *)> oracle_function_;
// The history of the states in this beam. The vector indexes across steps. // The history of the states in this beam. The vector indexes across steps.
// For every step, there is a vector in the vector. This inner vector denotes // For every step, there is a vector in the vector. This inner vector denotes
...@@ -355,9 +457,12 @@ class Beam { ...@@ -355,9 +457,12 @@ class Beam {
// The number of steps taken so far. // The number of steps taken so far.
int num_steps_; int num_steps_;
// Whether to track golden states.
bool track_gold_;
}; };
} // namespace dragnn } // namespace dragnn
} // namespace syntaxnet } // namespace syntaxnet
#endif // NLP_SAFT_OPENSOURCE_DRAGNN_CORE_BEAM_H_ #endif // DRAGNN_CORE_BEAM_H_
This diff is collapsed.
...@@ -13,12 +13,17 @@ ...@@ -13,12 +13,17 @@
// limitations under the License. // limitations under the License.
// ============================================================================= // =============================================================================
#ifndef NLP_SAFT_OPENSOURCE_DRAGNN_CORE_COMPONENT_REGISTRY_H_ #ifndef DRAGNN_CORE_COMPONENT_REGISTRY_H_
#define NLP_SAFT_OPENSOURCE_DRAGNN_CORE_COMPONENT_REGISTRY_H_ #define DRAGNN_CORE_COMPONENT_REGISTRY_H_
#include "dragnn/core/interfaces/component.h" #include "dragnn/core/interfaces/component.h"
#include "syntaxnet/registry.h" #include "syntaxnet/registry.h"
namespace syntaxnet {
// Class registry for DRAGNN components.
DECLARE_SYNTAXNET_CLASS_REGISTRY("DRAGNN Component", dragnn::Component);
} // namespace syntaxnet
// Macro to add a component to the registry. This macro associates a class with // Macro to add a component to the registry. This macro associates a class with
// its class name as a string, so FooComponent would be associated with the // its class name as a string, so FooComponent would be associated with the
// string "FooComponent". // string "FooComponent".
...@@ -26,4 +31,4 @@ ...@@ -26,4 +31,4 @@
REGISTER_SYNTAXNET_CLASS_COMPONENT(syntaxnet::dragnn::Component, #component, \ REGISTER_SYNTAXNET_CLASS_COMPONENT(syntaxnet::dragnn::Component, #component, \
component) component)
#endif // NLP_SAFT_OPENSOURCE_DRAGNN_CORE_COMPONENT_REGISTRY_H_ #endif // DRAGNN_CORE_COMPONENT_REGISTRY_H_
...@@ -13,13 +13,14 @@ ...@@ -13,13 +13,14 @@
// limitations under the License. // limitations under the License.
// ============================================================================= // =============================================================================
#ifndef NLP_SAFT_OPENSOURCE_DRAGNN_CORE_COMPUTE_SESSION_H_ #ifndef DRAGNN_CORE_COMPUTE_SESSION_H_
#define NLP_SAFT_OPENSOURCE_DRAGNN_CORE_COMPUTE_SESSION_H_ #define DRAGNN_CORE_COMPUTE_SESSION_H_
#include <string> #include <string>
#include "dragnn/components/util/bulk_feature_extractor.h" #include "dragnn/components/util/bulk_feature_extractor.h"
#include "dragnn/core/index_translator.h" #include "dragnn/core/index_translator.h"
#include "dragnn/core/input_batch_cache.h"
#include "dragnn/core/interfaces/component.h" #include "dragnn/core/interfaces/component.h"
#include "dragnn/protos/spec.pb.h" #include "dragnn/protos/spec.pb.h"
#include "dragnn/protos/trace.pb.h" #include "dragnn/protos/trace.pb.h"
...@@ -64,10 +65,11 @@ class ComputeSession { ...@@ -64,10 +65,11 @@ class ComputeSession {
// Advance the given component using the component's oracle. // Advance the given component using the component's oracle.
virtual void AdvanceFromOracle(const string &component_name) = 0; virtual void AdvanceFromOracle(const string &component_name) = 0;
// Advance the given component using the given score matrix. // Advance the given component using the given score matrix, which is
virtual void AdvanceFromPrediction(const string &component_name, // |num_items| x |num_actions|.
const float score_matrix[], virtual bool AdvanceFromPrediction(const string &component_name,
int score_matrix_length) = 0; const float *score_matrix, int num_items,
int num_actions) = 0;
// Get the input features for the given component and channel. This passes // Get the input features for the given component and channel. This passes
// through to the relevant Component's GetFixedFeatures() call. // through to the relevant Component's GetFixedFeatures() call.
...@@ -84,6 +86,15 @@ class ComputeSession { ...@@ -84,6 +86,15 @@ class ComputeSession {
virtual int BulkGetInputFeatures(const string &component_name, virtual int BulkGetInputFeatures(const string &component_name,
const BulkFeatureExtractor &extractor) = 0; const BulkFeatureExtractor &extractor) = 0;
// Directly computes the embedding matrix for all channels, advancing the
// component via the oracle until it is terminal. This call takes a vector
// of float embedding matrices, one per channel, in channel order.
virtual void BulkEmbedFixedFeatures(
const string &component_name, int batch_size_padding,
int num_steps_padding, int output_array_size,
const vector<const float *> &per_channel_embeddings,
float *embedding_output) = 0;
// Get the input features for the given component and channel. This function // Get the input features for the given component and channel. This function
// can return empty LinkFeatures protos, which represent unused padding slots // can return empty LinkFeatures protos, which represent unused padding slots
// in the output weight tensor. // in the output weight tensor.
...@@ -111,6 +122,10 @@ class ComputeSession { ...@@ -111,6 +122,10 @@ class ComputeSession {
// Provides the ComputeSession with a batch of data to compute. // Provides the ComputeSession with a batch of data to compute.
virtual void SetInputData(const std::vector<string> &data) = 0; virtual void SetInputData(const std::vector<string> &data) = 0;
// Like SetInputData(), but accepts an InputBatchCache directly, potentially
// bypassing de-serialization.
virtual void SetInputBatchCache(std::unique_ptr<InputBatchCache> batch) = 0;
// Resets all components owned by this ComputeSession. // Resets all components owned by this ComputeSession.
virtual void ResetSession() = 0; virtual void ResetSession() = 0;
...@@ -127,9 +142,14 @@ class ComputeSession { ...@@ -127,9 +142,14 @@ class ComputeSession {
// validate correct construction of translators in tests. // validate correct construction of translators in tests.
virtual const std::vector<const IndexTranslator *> Translators( virtual const std::vector<const IndexTranslator *> Translators(
const string &component_name) const = 0; const string &component_name) const = 0;
// Get a given component. CHECK-fail if the component's IsReady method
// returns false.
virtual Component *GetReadiedComponent(
const string &component_name) const = 0;
}; };
} // namespace dragnn } // namespace dragnn
} // namespace syntaxnet } // namespace syntaxnet
#endif // NLP_SAFT_OPENSOURCE_DRAGNN_CORE_COMPUTE_SESSION_H_ #endif // DRAGNN_CORE_COMPUTE_SESSION_H_
...@@ -161,11 +161,11 @@ void ComputeSessionImpl::AdvanceFromOracle(const string &component_name) { ...@@ -161,11 +161,11 @@ void ComputeSessionImpl::AdvanceFromOracle(const string &component_name) {
GetReadiedComponent(component_name)->AdvanceFromOracle(); GetReadiedComponent(component_name)->AdvanceFromOracle();
} }
void ComputeSessionImpl::AdvanceFromPrediction(const string &component_name, bool ComputeSessionImpl::AdvanceFromPrediction(const string &component_name,
const float score_matrix[], const float *score_matrix,
int score_matrix_length) { int num_items, int num_actions) {
GetReadiedComponent(component_name) return GetReadiedComponent(component_name)
->AdvanceFromPrediction(score_matrix, score_matrix_length); ->AdvanceFromPrediction(score_matrix, num_items, num_actions);
} }
int ComputeSessionImpl::GetInputFeatures( int ComputeSessionImpl::GetInputFeatures(
...@@ -182,6 +182,16 @@ int ComputeSessionImpl::BulkGetInputFeatures( ...@@ -182,6 +182,16 @@ int ComputeSessionImpl::BulkGetInputFeatures(
return GetReadiedComponent(component_name)->BulkGetFixedFeatures(extractor); return GetReadiedComponent(component_name)->BulkGetFixedFeatures(extractor);
} }
void ComputeSessionImpl::BulkEmbedFixedFeatures(
const string &component_name, int batch_size_padding, int num_steps_padding,
int output_array_size, const vector<const float *> &per_channel_embeddings,
float *embedding_output) {
return GetReadiedComponent(component_name)
->BulkEmbedFixedFeatures(batch_size_padding, num_steps_padding,
output_array_size, per_channel_embeddings,
embedding_output);
}
std::vector<LinkFeatures> ComputeSessionImpl::GetTranslatedLinkFeatures( std::vector<LinkFeatures> ComputeSessionImpl::GetTranslatedLinkFeatures(
const string &component_name, int channel_id) { const string &component_name, int channel_id) {
auto *component = GetReadiedComponent(component_name); auto *component = GetReadiedComponent(component_name);
...@@ -288,6 +298,11 @@ void ComputeSessionImpl::SetInputData(const std::vector<string> &data) { ...@@ -288,6 +298,11 @@ void ComputeSessionImpl::SetInputData(const std::vector<string> &data) {
input_data_.reset(new InputBatchCache(data)); input_data_.reset(new InputBatchCache(data));
} }
void ComputeSessionImpl::SetInputBatchCache(
std::unique_ptr<InputBatchCache> batch) {
input_data_ = std::move(batch);
}
void ComputeSessionImpl::ResetSession() { void ComputeSessionImpl::ResetSession() {
// Reset all component states. // Reset all component states.
for (auto &component_pair : components_) { for (auto &component_pair : components_) {
...@@ -308,6 +323,7 @@ const std::vector<const IndexTranslator *> ComputeSessionImpl::Translators( ...@@ -308,6 +323,7 @@ const std::vector<const IndexTranslator *> ComputeSessionImpl::Translators(
const string &component_name) const { const string &component_name) const {
auto translators = GetTranslators(component_name); auto translators = GetTranslators(component_name);
std::vector<const IndexTranslator *> const_translators; std::vector<const IndexTranslator *> const_translators;
const_translators.reserve(translators.size());
for (const auto &translator : translators) { for (const auto &translator : translators) {
const_translators.push_back(translator); const_translators.push_back(translator);
} }
......
...@@ -13,8 +13,8 @@ ...@@ -13,8 +13,8 @@
// limitations under the License. // limitations under the License.
// ============================================================================= // =============================================================================
#ifndef NLP_SAFT_OPENSOURCE_DRAGNN_CORE_COMPUTE_SESSION_IMPL_H_ #ifndef DRAGNN_CORE_COMPUTE_SESSION_IMPL_H_
#define NLP_SAFT_OPENSOURCE_DRAGNN_CORE_COMPUTE_SESSION_IMPL_H_ #define DRAGNN_CORE_COMPUTE_SESSION_IMPL_H_
#include <memory> #include <memory>
...@@ -55,9 +55,9 @@ class ComputeSessionImpl : public ComputeSession { ...@@ -55,9 +55,9 @@ class ComputeSessionImpl : public ComputeSession {
void AdvanceFromOracle(const string &component_name) override; void AdvanceFromOracle(const string &component_name) override;
void AdvanceFromPrediction(const string &component_name, bool AdvanceFromPrediction(const string &component_name,
const float score_matrix[], const float *score_matrix, int num_items,
int score_matrix_length) override; int num_actions) override;
int GetInputFeatures(const string &component_name, int GetInputFeatures(const string &component_name,
std::function<int32 *(int)> allocate_indices, std::function<int32 *(int)> allocate_indices,
...@@ -68,6 +68,12 @@ class ComputeSessionImpl : public ComputeSession { ...@@ -68,6 +68,12 @@ class ComputeSessionImpl : public ComputeSession {
int BulkGetInputFeatures(const string &component_name, int BulkGetInputFeatures(const string &component_name,
const BulkFeatureExtractor &extractor) override; const BulkFeatureExtractor &extractor) override;
void BulkEmbedFixedFeatures(
const string &component_name, int batch_size_padding,
int num_steps_padding, int output_array_size,
const vector<const float *> &per_channel_embeddings,
float *embedding_output) override;
std::vector<LinkFeatures> GetTranslatedLinkFeatures( std::vector<LinkFeatures> GetTranslatedLinkFeatures(
const string &component_name, int channel_id) override; const string &component_name, int channel_id) override;
...@@ -84,6 +90,8 @@ class ComputeSessionImpl : public ComputeSession { ...@@ -84,6 +90,8 @@ class ComputeSessionImpl : public ComputeSession {
void SetInputData(const std::vector<string> &data) override; void SetInputData(const std::vector<string> &data) override;
void SetInputBatchCache(std::unique_ptr<InputBatchCache> batch) override;
void ResetSession() override; void ResetSession() override;
void SetTracing(bool tracing_on) override; void SetTracing(bool tracing_on) override;
...@@ -95,14 +103,14 @@ class ComputeSessionImpl : public ComputeSession { ...@@ -95,14 +103,14 @@ class ComputeSessionImpl : public ComputeSession {
const std::vector<const IndexTranslator *> Translators( const std::vector<const IndexTranslator *> Translators(
const string &component_name) const override; const string &component_name) const override;
// Get a given component. CHECK-fail if the component's IsReady method
// returns false.
Component *GetReadiedComponent(const string &component_name) const override;
private: private:
// Get a given component. Fails if the component is not found. // Get a given component. Fails if the component is not found.
Component *GetComponent(const string &component_name) const; Component *GetComponent(const string &component_name) const;
// Get a given component. CHECK-fail if the component's IsReady method
// returns false.
Component *GetReadiedComponent(const string &component_name) const;
// Get the index translators for the given component. // Get the index translators for the given component.
const std::vector<IndexTranslator *> &GetTranslators( const std::vector<IndexTranslator *> &GetTranslators(
const string &component_name) const; const string &component_name) const;
...@@ -154,4 +162,4 @@ class ComputeSessionImpl : public ComputeSession { ...@@ -154,4 +162,4 @@ class ComputeSessionImpl : public ComputeSession {
} // namespace dragnn } // namespace dragnn
} // namespace syntaxnet } // namespace syntaxnet
#endif // NLP_SAFT_OPENSOURCE_DRAGNN_CORE_COMPUTE_SESSION_IMPL_H_ #endif // DRAGNN_CORE_COMPUTE_SESSION_IMPL_H_
...@@ -22,7 +22,9 @@ ...@@ -22,7 +22,9 @@
#include "dragnn/core/component_registry.h" #include "dragnn/core/component_registry.h"
#include "dragnn/core/compute_session.h" #include "dragnn/core/compute_session.h"
#include "dragnn/core/compute_session_pool.h" #include "dragnn/core/compute_session_pool.h"
#include "dragnn/core/input_batch_cache.h"
#include "dragnn/core/interfaces/component.h" #include "dragnn/core/interfaces/component.h"
#include "dragnn/core/interfaces/input_batch.h"
#include "dragnn/core/test/generic.h" #include "dragnn/core/test/generic.h"
#include "dragnn/core/test/mock_component.h" #include "dragnn/core/test/mock_component.h"
#include "dragnn/core/test/mock_transition_state.h" #include "dragnn/core/test/mock_transition_state.h"
...@@ -65,8 +67,10 @@ class TestComponentType1 : public Component { ...@@ -65,8 +67,10 @@ class TestComponentType1 : public Component {
int GetSourceBeamIndex(int current_index, int batch) const override { int GetSourceBeamIndex(int current_index, int batch) const override {
return 0; return 0;
} }
void AdvanceFromPrediction(const float transition_matrix[], bool AdvanceFromPrediction(const float *score_matrix, int num_items,
int matrix_length) override {} int num_actions) override {
return true;
}
void AdvanceFromOracle() override {} void AdvanceFromOracle() override {}
bool IsTerminal() const override { return true; } bool IsTerminal() const override { return true; }
std::function<int(int, int, int)> GetStepLookupFunction( std::function<int(int, int, int)> GetStepLookupFunction(
...@@ -83,6 +87,10 @@ class TestComponentType1 : public Component { ...@@ -83,6 +87,10 @@ class TestComponentType1 : public Component {
int channel_id) const override { int channel_id) const override {
return 0; return 0;
} }
void BulkEmbedFixedFeatures(
int batch_size_padding, int num_steps_padding, int embedding_size,
const vector<const float *> &per_channel_embeddings,
float *embedding_output) override {}
int BulkGetFixedFeatures(const BulkFeatureExtractor &extractor) override { int BulkGetFixedFeatures(const BulkFeatureExtractor &extractor) override {
return 0; return 0;
} }
...@@ -133,8 +141,10 @@ class TestComponentType2 : public Component { ...@@ -133,8 +141,10 @@ class TestComponentType2 : public Component {
int GetSourceBeamIndex(int current_index, int batch) const override { int GetSourceBeamIndex(int current_index, int batch) const override {
return 0; return 0;
} }
void AdvanceFromPrediction(const float transition_matrix[], bool AdvanceFromPrediction(const float *score_matrix, int num_items,
int matrix_length) override {} int num_actions) override {
return true;
}
void AdvanceFromOracle() override {} void AdvanceFromOracle() override {}
bool IsTerminal() const override { return true; } bool IsTerminal() const override { return true; }
std::function<int(int, int, int)> GetStepLookupFunction( std::function<int(int, int, int)> GetStepLookupFunction(
...@@ -151,6 +161,10 @@ class TestComponentType2 : public Component { ...@@ -151,6 +161,10 @@ class TestComponentType2 : public Component {
int channel_id) const override { int channel_id) const override {
return 0; return 0;
} }
void BulkEmbedFixedFeatures(
int batch_size_padding, int num_steps_padding, int embedding_size,
const vector<const float *> &per_channel_embeddings,
float *embedding_output) override {}
int BulkGetFixedFeatures(const BulkFeatureExtractor &extractor) override { int BulkGetFixedFeatures(const BulkFeatureExtractor &extractor) override {
return 0; return 0;
} }
...@@ -201,8 +215,14 @@ class UnreadyComponent : public Component { ...@@ -201,8 +215,14 @@ class UnreadyComponent : public Component {
int GetSourceBeamIndex(int current_index, int batch) const override { int GetSourceBeamIndex(int current_index, int batch) const override {
return 0; return 0;
} }
void AdvanceFromPrediction(const float transition_matrix[], bool AdvanceFromPrediction(const float *score_matrix, int num_items,
int matrix_length) override {} int num_actions) override {
return true;
}
void BulkEmbedFixedFeatures(
int batch_size_padding, int num_steps_padding, int embedding_size,
const vector<const float *> &per_channel_embeddings,
float *embedding_output) override {}
void AdvanceFromOracle() override {} void AdvanceFromOracle() override {}
bool IsTerminal() const override { return false; } bool IsTerminal() const override { return false; }
std::function<int(int, int, int)> GetStepLookupFunction( std::function<int(int, int, int)> GetStepLookupFunction(
...@@ -254,6 +274,18 @@ class ComputeSessionImplTestPoolAccessor { ...@@ -254,6 +274,18 @@ class ComputeSessionImplTestPoolAccessor {
} }
}; };
// An InputBatch that uses the serialized data directly.
class IdentityBatch : public InputBatch {
public:
// Implements InputBatch.
void SetData(const std::vector<string> &data) override { data_ = data; }
int GetSize() const override { return data_.size(); }
const std::vector<string> GetSerializedData() const override { return data_; }
private:
std::vector<string> data_; // the batch data
};
// ***************************************************************************** // *****************************************************************************
// Tests begin here. // Tests begin here.
// ***************************************************************************** // *****************************************************************************
...@@ -739,7 +771,7 @@ TEST(ComputeSessionImplTest, InitializesComponentWithSource) { ...@@ -739,7 +771,7 @@ TEST(ComputeSessionImplTest, InitializesComponentWithSource) {
EXPECT_CALL(*mock_components["component_one"], GetBeam()) EXPECT_CALL(*mock_components["component_one"], GetBeam())
.WillOnce(Return(beam)); .WillOnce(Return(beam));
// Expect that the second component will recieve that beam. // Expect that the second component will receive that beam.
EXPECT_CALL(*mock_components["component_two"], EXPECT_CALL(*mock_components["component_two"],
InitializeData(beam, kMaxBeamSize, NotNull())); InitializeData(beam, kMaxBeamSize, NotNull()));
...@@ -899,7 +931,7 @@ TEST(ComputeSessionImplTest, SetTracingPropagatesToAllComponents) { ...@@ -899,7 +931,7 @@ TEST(ComputeSessionImplTest, SetTracingPropagatesToAllComponents) {
EXPECT_CALL(*mock_components["component_one"], GetBeam()) EXPECT_CALL(*mock_components["component_one"], GetBeam())
.WillOnce(Return(beam)); .WillOnce(Return(beam));
// Expect that the second component will recieve that beam, and then its // Expect that the second component will receive that beam, and then its
// tracing will be initialized. // tracing will be initialized.
EXPECT_CALL(*mock_components["component_two"], EXPECT_CALL(*mock_components["component_two"],
InitializeData(beam, kMaxBeamSize, NotNull())); InitializeData(beam, kMaxBeamSize, NotNull()));
...@@ -1084,12 +1116,12 @@ TEST(ComputeSessionImplTest, InterfacePassesThrough) { ...@@ -1084,12 +1116,12 @@ TEST(ComputeSessionImplTest, InterfacePassesThrough) {
session->AdvanceFromOracle("component_one"); session->AdvanceFromOracle("component_one");
// AdvanceFromPrediction() // AdvanceFromPrediction()
constexpr int kScoreMatrixLength = 3; const int kNumActions = 1;
const float score_matrix[kScoreMatrixLength] = {1.0, 2.3, 4.5}; const float score_matrix[] = {1.0, 2.3, 4.5};
EXPECT_CALL(*mock_components["component_one"], EXPECT_CALL(*mock_components["component_one"],
AdvanceFromPrediction(score_matrix, kScoreMatrixLength)); AdvanceFromPrediction(score_matrix, batch_size, kNumActions));
session->AdvanceFromPrediction("component_one", score_matrix, session->AdvanceFromPrediction("component_one", score_matrix, batch_size,
kScoreMatrixLength); kNumActions);
// GetFixedFeatures // GetFixedFeatures
auto allocate_indices = [](int size) -> int32 * { return nullptr; }; auto allocate_indices = [](int size) -> int32 * { return nullptr; };
...@@ -1109,6 +1141,11 @@ TEST(ComputeSessionImplTest, InterfacePassesThrough) { ...@@ -1109,6 +1141,11 @@ TEST(ComputeSessionImplTest, InterfacePassesThrough) {
.WillOnce(Return(0)); .WillOnce(Return(0));
EXPECT_EQ(0, session->BulkGetInputFeatures("component_one", extractor)); EXPECT_EQ(0, session->BulkGetInputFeatures("component_one", extractor));
// BulkEmbedFixedFeatures
EXPECT_CALL(*mock_components["component_one"],
BulkEmbedFixedFeatures(1, 2, 3, _, _));
session->BulkEmbedFixedFeatures("component_one", 1, 2, 3, {nullptr}, nullptr);
// EmitOracleLabels() // EmitOracleLabels()
std::vector<std::vector<int>> oracle_labels = {{0, 1}, {2, 3}}; std::vector<std::vector<int>> oracle_labels = {{0, 1}, {2, 3}};
EXPECT_CALL(*mock_components["component_one"], GetOracleLabels()) EXPECT_CALL(*mock_components["component_one"], GetOracleLabels())
...@@ -1154,7 +1191,7 @@ TEST(ComputeSessionImplTest, InterfaceRequiresReady) { ...@@ -1154,7 +1191,7 @@ TEST(ComputeSessionImplTest, InterfaceRequiresReady) {
constexpr int kScoreMatrixLength = 3; constexpr int kScoreMatrixLength = 3;
const float score_matrix[kScoreMatrixLength] = {1.0, 2.3, 4.5}; const float score_matrix[kScoreMatrixLength] = {1.0, 2.3, 4.5};
EXPECT_DEATH(session->AdvanceFromPrediction("component_one", score_matrix, EXPECT_DEATH(session->AdvanceFromPrediction("component_one", score_matrix,
kScoreMatrixLength), kScoreMatrixLength, 1),
"without first initializing it"); "without first initializing it");
constexpr int kArbitraryChannelId = 3; constexpr int kArbitraryChannelId = 3;
EXPECT_DEATH(session->GetInputFeatures("component_one", nullptr, nullptr, EXPECT_DEATH(session->GetInputFeatures("component_one", nullptr, nullptr,
...@@ -1163,10 +1200,32 @@ TEST(ComputeSessionImplTest, InterfaceRequiresReady) { ...@@ -1163,10 +1200,32 @@ TEST(ComputeSessionImplTest, InterfaceRequiresReady) {
BulkFeatureExtractor extractor(nullptr, nullptr, nullptr, false, 0, 0); BulkFeatureExtractor extractor(nullptr, nullptr, nullptr, false, 0, 0);
EXPECT_DEATH(session->BulkGetInputFeatures("component_one", extractor), EXPECT_DEATH(session->BulkGetInputFeatures("component_one", extractor),
"without first initializing it"); "without first initializing it");
EXPECT_DEATH(session->BulkEmbedFixedFeatures("component_one", 0, 0, 0,
{nullptr}, nullptr),
"without first initializing it");
EXPECT_DEATH( EXPECT_DEATH(
session->GetTranslatedLinkFeatures("component_one", kArbitraryChannelId), session->GetTranslatedLinkFeatures("component_one", kArbitraryChannelId),
"without first initializing it"); "without first initializing it");
} }
TEST(ComputeSessionImplTest, SetInputBatchCache) {
// Use empty protos since we won't interact with components.
MasterSpec spec;
GridPoint hyperparams;
ComputeSessionPool pool(spec, hyperparams);
auto session = pool.GetSession();
// Initialize a cached IdentityBatch.
const std::vector<string> data = {"foo", "bar", "baz"};
std::unique_ptr<InputBatchCache> input_batch_cache(new InputBatchCache(data));
input_batch_cache->GetAs<IdentityBatch>();
// Inject the cache into the session.
session->SetInputBatchCache(std::move(input_batch_cache));
// Check that the injected batch can be retrieved.
EXPECT_EQ(session->GetSerializedPredictions(), data);
}
} // namespace dragnn } // namespace dragnn
} // namespace syntaxnet } // namespace syntaxnet
...@@ -13,8 +13,8 @@ ...@@ -13,8 +13,8 @@
// limitations under the License. // limitations under the License.
// ============================================================================= // =============================================================================
#ifndef NLP_SAFT_OPENSOURCE_DRAGNN_CORE_COMPUTE_SESSION_POOL_H_ #ifndef DRAGNN_CORE_COMPUTE_SESSION_POOL_H_
#define NLP_SAFT_OPENSOURCE_DRAGNN_CORE_COMPUTE_SESSION_POOL_H_ #define DRAGNN_CORE_COMPUTE_SESSION_POOL_H_
#include <memory> #include <memory>
...@@ -29,14 +29,14 @@ namespace dragnn { ...@@ -29,14 +29,14 @@ namespace dragnn {
class ComputeSessionPool { class ComputeSessionPool {
public: public:
// Create a ComputeSessionPool that creates ComputeSessions for the given // Creates a ComputeSessionPool that creates ComputeSessions for the given
// MasterSpec and hyperparameters. // MasterSpec and hyperparameters.
ComputeSessionPool(const MasterSpec &master_spec, ComputeSessionPool(const MasterSpec &master_spec,
const GridPoint &hyperparams); const GridPoint &hyperparams);
virtual ~ComputeSessionPool(); virtual ~ComputeSessionPool();
// Get a ComputeSession. This function will attempt to use an already-created // Gets a ComputeSession. This function will attempt to use an already-created
// ComputeSession, but if none are available a new one will be created. // ComputeSession, but if none are available a new one will be created.
std::unique_ptr<ComputeSession> GetSession(); std::unique_ptr<ComputeSession> GetSession();
...@@ -49,6 +49,12 @@ class ComputeSessionPool { ...@@ -49,6 +49,12 @@ class ComputeSessionPool {
return num_unique_sessions_ - sessions_.size(); return num_unique_sessions_ - sessions_.size();
} }
// Returns the number of unique sessions that have been created.
int num_unique_sessions() { return num_unique_sessions_; }
// Returns a reference to the underlying spec for this pool.
const MasterSpec &GetSpec() const { return master_spec_; }
private: private:
friend class ComputeSessionImplTestPoolAccessor; friend class ComputeSessionImplTestPoolAccessor;
friend class ComputeSessionPoolTestPoolAccessor; friend class ComputeSessionPoolTestPoolAccessor;
...@@ -99,4 +105,4 @@ class ComputeSessionPool { ...@@ -99,4 +105,4 @@ class ComputeSessionPool {
} // namespace dragnn } // namespace dragnn
} // namespace syntaxnet } // namespace syntaxnet
#endif // NLP_SAFT_OPENSOURCE_DRAGNN_CORE_COMPUTE_SESSION_POOL_H_ #endif // DRAGNN_CORE_COMPUTE_SESSION_POOL_H_
...@@ -207,6 +207,7 @@ TEST(ComputeSessionPoolTest, SupportsMultithreadedAccess) { ...@@ -207,6 +207,7 @@ TEST(ComputeSessionPoolTest, SupportsMultithreadedAccess) {
std::vector<std::unique_ptr<tensorflow::Thread>> request_threads; std::vector<std::unique_ptr<tensorflow::Thread>> request_threads;
constexpr int kNumThreadsToTest = 100; constexpr int kNumThreadsToTest = 100;
request_threads.reserve(kNumThreadsToTest);
for (int i = 0; i < kNumThreadsToTest; ++i) { for (int i = 0; i < kNumThreadsToTest; ++i) {
request_threads.push_back(std::unique_ptr<tensorflow::Thread>( request_threads.push_back(std::unique_ptr<tensorflow::Thread>(
tensorflow::Env::Default()->StartThread( tensorflow::Env::Default()->StartThread(
......
...@@ -13,8 +13,8 @@ ...@@ -13,8 +13,8 @@
// limitations under the License. // limitations under the License.
// ============================================================================= // =============================================================================
#ifndef NLP_SAFT_OPENSOURCE_DRAGNN_CORE_INDEX_TRANSLATOR_H_ #ifndef DRAGNN_CORE_INDEX_TRANSLATOR_H_
#define NLP_SAFT_OPENSOURCE_DRAGNN_CORE_INDEX_TRANSLATOR_H_ #define DRAGNN_CORE_INDEX_TRANSLATOR_H_
#include <memory> #include <memory>
#include <vector> #include <vector>
...@@ -80,4 +80,4 @@ class IndexTranslator { ...@@ -80,4 +80,4 @@ class IndexTranslator {
} // namespace dragnn } // namespace dragnn
} // namespace syntaxnet } // namespace syntaxnet
#endif // NLP_SAFT_OPENSOURCE_DRAGNN_CORE_INDEX_TRANSLATOR_H_ #endif // DRAGNN_CORE_INDEX_TRANSLATOR_H_
...@@ -13,12 +13,15 @@ ...@@ -13,12 +13,15 @@
// limitations under the License. // limitations under the License.
// ============================================================================= // =============================================================================
#ifndef NLP_SAFT_OPENSOURCE_DRAGNN_CORE_INPUT_BATCH_CACHE_H_ #ifndef DRAGNN_CORE_INPUT_BATCH_CACHE_H_
#define NLP_SAFT_OPENSOURCE_DRAGNN_CORE_INPUT_BATCH_CACHE_H_ #define DRAGNN_CORE_INPUT_BATCH_CACHE_H_
#include <memory> #include <memory>
#include <string> #include <string>
#include <type_traits>
#include <typeindex> #include <typeindex>
#include <typeinfo>
#include <utility>
#include "dragnn/core/interfaces/input_batch.h" #include "dragnn/core/interfaces/input_batch.h"
#include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/logging.h"
...@@ -42,6 +45,18 @@ class InputBatchCache { ...@@ -42,6 +45,18 @@ class InputBatchCache {
explicit InputBatchCache(const std::vector<string> &data) explicit InputBatchCache(const std::vector<string> &data)
: stored_type_(std::type_index(typeid(void))), source_data_(data) {} : stored_type_(std::type_index(typeid(void))), source_data_(data) {}
// Creates a InputBatchCache from the |batch|. InputBatchSubclass must be a
// strict subclass of InputBatch, and |batch| must be non-null. All calls to
// GetAs must match InputBatchSubclass.
template <class InputBatchSubclass>
explicit InputBatchCache(std::unique_ptr<InputBatchSubclass> batch)
: stored_type_(std::type_index(typeid(InputBatchSubclass))),
converted_data_(std::move(batch)) {
static_assert(IsStrictInputBatchSubclass<InputBatchSubclass>(),
"InputBatchCache requires a strict subclass of InputBatch");
CHECK(converted_data_) << "Cannot initialize from a null InputBatch";
}
// Adds a single string to the cache. Only useable before GetAs() has been // Adds a single string to the cache. Only useable before GetAs() has been
// called. // called.
void AddData(const string &data) { void AddData(const string &data) {
...@@ -52,10 +67,14 @@ class InputBatchCache { ...@@ -52,10 +67,14 @@ class InputBatchCache {
} }
// Converts the stored strings into protos and return them in a specific // Converts the stored strings into protos and return them in a specific
// InputBatch subclass. T should always be of type InputBatch. After this // InputBatch subclass. T should always be a strict subclass of InputBatch.
// method is called once, all further calls must be of the same data type. // After this method is called once, all further calls must be of the same
// data type.
template <class T> template <class T>
T *GetAs() { T *GetAs() {
static_assert(
IsStrictInputBatchSubclass<T>(),
"GetAs<T>() requires that T is a strict subclass of InputBatch");
if (!converted_data_) { if (!converted_data_) {
stored_type_ = std::type_index(typeid(T)); stored_type_ = std::type_index(typeid(T));
converted_data_.reset(new T()); converted_data_.reset(new T());
...@@ -69,14 +88,27 @@ class InputBatchCache { ...@@ -69,14 +88,27 @@ class InputBatchCache {
return dynamic_cast<T *>(converted_data_.get()); return dynamic_cast<T *>(converted_data_.get());
} }
// Returns the size of the batch. Requires that GetAs() has been called.
int Size() const {
CHECK(converted_data_) << "Cannot return batch size without data.";
return converted_data_->GetSize();
}
// Returns the serialized representation of the data held in the input batch // Returns the serialized representation of the data held in the input batch
// object within this cache. // object within this cache. Requires that GetAs() has been called.
const std::vector<string> SerializedData() const { const std::vector<string> SerializedData() const {
CHECK(converted_data_) << "Cannot return batch without data."; CHECK(converted_data_) << "Cannot return batch without data.";
return converted_data_->GetSerializedData(); return converted_data_->GetSerializedData();
} }
private: private:
// Returns true if InputBatchSubclass is a strict subclass of InputBatch.
template <class InputBatchSubclass>
static constexpr bool IsStrictInputBatchSubclass() {
return std::is_base_of<InputBatch, InputBatchSubclass>::value &&
!std::is_same<InputBatch, InputBatchSubclass>::value;
}
// The typeid of the stored data. // The typeid of the stored data.
std::type_index stored_type_; std::type_index stored_type_;
...@@ -90,4 +122,4 @@ class InputBatchCache { ...@@ -90,4 +122,4 @@ class InputBatchCache {
} // namespace dragnn } // namespace dragnn
} // namespace syntaxnet } // namespace syntaxnet
#endif // NLP_SAFT_OPENSOURCE_DRAGNN_CORE_INPUT_BATCH_CACHE_H_ #endif // DRAGNN_CORE_INPUT_BATCH_CACHE_H_
...@@ -32,6 +32,8 @@ class StringData : public InputBatch { ...@@ -32,6 +32,8 @@ class StringData : public InputBatch {
} }
} }
int GetSize() const override { return data_.size(); }
const std::vector<string> GetSerializedData() const override { return data_; } const std::vector<string> GetSerializedData() const override { return data_; }
std::vector<string> *data() { return &data_; } std::vector<string> *data() { return &data_; }
...@@ -50,6 +52,8 @@ class DifferentStringData : public InputBatch { ...@@ -50,6 +52,8 @@ class DifferentStringData : public InputBatch {
} }
} }
int GetSize() const override { return data_.size(); }
const std::vector<string> GetSerializedData() const override { return data_; } const std::vector<string> GetSerializedData() const override { return data_; }
std::vector<string> *data() { return &data_; } std::vector<string> *data() { return &data_; }
...@@ -58,6 +62,11 @@ class DifferentStringData : public InputBatch { ...@@ -58,6 +62,11 @@ class DifferentStringData : public InputBatch {
std::vector<string> data_; std::vector<string> data_;
}; };
// Expects that two pointers have the same address.
void ExpectSameAddress(const void *pointer1, const void *pointer2) {
EXPECT_EQ(pointer1, pointer2);
}
TEST(InputBatchCacheTest, ConvertsSingleInput) { TEST(InputBatchCacheTest, ConvertsSingleInput) {
string test_string = "Foo"; string test_string = "Foo";
InputBatchCache generic_set(test_string); InputBatchCache generic_set(test_string);
...@@ -118,5 +127,48 @@ TEST(InputBatchCacheTest, ConvertsAddedInputDiesAfterGetAs) { ...@@ -118,5 +127,48 @@ TEST(InputBatchCacheTest, ConvertsAddedInputDiesAfterGetAs) {
"after the cache has been converted"); "after the cache has been converted");
} }
TEST(InputBatchCacheTest, SerializedDataAndSize) {
InputBatchCache generic_set;
generic_set.AddData("Foo");
generic_set.AddData("Bar");
generic_set.GetAs<StringData>();
const std::vector<string> expected_data = {"Foo_converted", "Bar_converted"};
EXPECT_EQ(expected_data, generic_set.SerializedData());
EXPECT_EQ(2, generic_set.Size());
}
TEST(InputBatchCacheTest, InitializeFromInputBatch) {
const std::vector<string> kInputData = {"foo", "bar", "baz"};
const std::vector<string> kExpectedData = {"foo_converted", //
"bar_converted", //
"baz_converted"};
std::unique_ptr<StringData> string_data(new StringData());
string_data->SetData(kInputData);
const StringData *string_data_ptr = string_data.get();
InputBatchCache generic_set(std::move(string_data));
auto data = generic_set.GetAs<StringData>();
ExpectSameAddress(string_data_ptr, data);
EXPECT_EQ(data->GetSize(), 3);
EXPECT_EQ(data->GetSerializedData(), kExpectedData);
EXPECT_EQ(*data->data(), kExpectedData);
// AddData() shouldn't work since the cache is already populated.
EXPECT_DEATH(generic_set.AddData("YOU MAY NOT DO THIS AND IT WILL DIE."),
"after the cache has been converted");
// GetAs() shouldn't work with a different type.
EXPECT_DEATH(generic_set.GetAs<DifferentStringData>(),
"Attempted to convert to two object types!");
}
TEST(InputBatchCacheTest, CannotInitializeFromNullInputBatch) {
EXPECT_DEATH(InputBatchCache(std::unique_ptr<StringData>()),
"Cannot initialize from a null InputBatch");
}
} // namespace dragnn } // namespace dragnn
} // namespace syntaxnet } // namespace syntaxnet
...@@ -13,8 +13,8 @@ ...@@ -13,8 +13,8 @@
// limitations under the License. // limitations under the License.
// ============================================================================= // =============================================================================
#ifndef NLP_SAFT_OPENSOURCE_DRAGNN_CORE_INTERFACES_CLONEABLE_TRANSITION_STATE_H_ #ifndef DRAGNN_CORE_INTERFACES_CLONEABLE_TRANSITION_STATE_H_
#define NLP_SAFT_OPENSOURCE_DRAGNN_CORE_INTERFACES_CLONEABLE_TRANSITION_STATE_H_ #define DRAGNN_CORE_INTERFACES_CLONEABLE_TRANSITION_STATE_H_
#include <memory> #include <memory>
#include <vector> #include <vector>
...@@ -33,26 +33,32 @@ class CloneableTransitionState : public TransitionState { ...@@ -33,26 +33,32 @@ class CloneableTransitionState : public TransitionState {
public: public:
~CloneableTransitionState<T>() override {} ~CloneableTransitionState<T>() override {}
// Initialize this TransitionState from a previous TransitionState. The // Initializes this TransitionState from a previous TransitionState. The
// ParentBeamIndex is the location of that previous TransitionState in the // ParentBeamIndex is the location of that previous TransitionState in the
// provided beam. // provided beam.
void Init(const TransitionState &parent) override = 0; void Init(const TransitionState &parent) override = 0;
// Return the beam index of the state passed into the initializer of this // Returns the beam index of the state passed into the initializer of this
// TransitionState. // TransitionState.
const int ParentBeamIndex() const override = 0; int ParentBeamIndex() const override = 0;
// Get the current beam index for this state. // Gets the current beam index for this state.
const int GetBeamIndex() const override = 0; int GetBeamIndex() const override = 0;
// Set the current beam index for this state. // Sets the current beam index for this state.
void SetBeamIndex(const int index) override = 0; void SetBeamIndex(int index) override = 0;
// Get the score associated with this transition state. // Gets the score associated with this transition state.
const float GetScore() const override = 0; float GetScore() const override = 0;
// Set the score associated with this transition state. // Sets the score associated with this transition state.
void SetScore(const float score) override = 0; void SetScore(float score) override = 0;
// Gets the gold-ness of this state (whether it is on the oracle path)
bool IsGold() const override = 0;
// Sets the gold-ness of this state.
void SetGold(bool is_gold) override = 0;
// Depicts this state as an HTML-language string. // Depicts this state as an HTML-language string.
string HTMLRepresentation() const override = 0; string HTMLRepresentation() const override = 0;
...@@ -64,4 +70,4 @@ class CloneableTransitionState : public TransitionState { ...@@ -64,4 +70,4 @@ class CloneableTransitionState : public TransitionState {
} // namespace dragnn } // namespace dragnn
} // namespace syntaxnet } // namespace syntaxnet
#endif // NLP_SAFT_OPENSOURCE_DRAGNN_CORE_INTERFACES_CLONEABLE_TRANSITION_STATE_H_ #endif // DRAGNN_CORE_INTERFACES_CLONEABLE_TRANSITION_STATE_H_
...@@ -13,8 +13,8 @@ ...@@ -13,8 +13,8 @@
// limitations under the License. // limitations under the License.
// ============================================================================= // =============================================================================
#ifndef NLP_SAFT_OPENSOURCE_DRAGNN_CORE_INTERFACES_COMPONENT_H_ #ifndef DRAGNN_CORE_INTERFACES_COMPONENT_H_
#define NLP_SAFT_OPENSOURCE_DRAGNN_CORE_INTERFACES_COMPONENT_H_ #define DRAGNN_CORE_INTERFACES_COMPONENT_H_
#include <vector> #include <vector>
...@@ -83,11 +83,13 @@ class Component : public RegisterableClass<Component> { ...@@ -83,11 +83,13 @@ class Component : public RegisterableClass<Component> {
virtual std::function<int(int, int, int)> GetStepLookupFunction( virtual std::function<int(int, int, int)> GetStepLookupFunction(
const string &method) = 0; const string &method) = 0;
// Advances this component from the given transition matrix. // Advances this component from the given transition matrix, which is
virtual void AdvanceFromPrediction(const float transition_matrix[], // |num_items| x |num_actions|.
int transition_matrix_length) = 0; virtual bool AdvanceFromPrediction(const float *score_matrix, int num_items,
int num_actions) = 0;
// Advances this component from the state oracles. // Advances this component from the state oracles. There is no return from
// this, since it should always succeed.
virtual void AdvanceFromOracle() = 0; virtual void AdvanceFromOracle() = 0;
// Returns true if all states within this component are terminal. // Returns true if all states within this component are terminal.
...@@ -110,6 +112,14 @@ class Component : public RegisterableClass<Component> { ...@@ -110,6 +112,14 @@ class Component : public RegisterableClass<Component> {
// BulkFeatureExtractor object to contain the functors and other information. // BulkFeatureExtractor object to contain the functors and other information.
virtual int BulkGetFixedFeatures(const BulkFeatureExtractor &extractor) = 0; virtual int BulkGetFixedFeatures(const BulkFeatureExtractor &extractor) = 0;
// Directly computes the embedding matrix for all channels, advancing the
// component via the oracle until it is terminal. This call takes a vector
// of EmbeddingMatrix structs, one per channel, in channel order.
virtual void BulkEmbedFixedFeatures(
int batch_size_padding, int num_steps_padding, int output_array_size,
const vector<const float *> &per_channel_embeddings,
float *embedding_output) = 0;
// Extracts and returns the vector of LinkFeatures for the specified // Extracts and returns the vector of LinkFeatures for the specified
// channel. Note: these are NOT translated. // channel. Note: these are NOT translated.
virtual std::vector<LinkFeatures> GetRawLinkFeatures( virtual std::vector<LinkFeatures> GetRawLinkFeatures(
...@@ -138,4 +148,4 @@ class Component : public RegisterableClass<Component> { ...@@ -138,4 +148,4 @@ class Component : public RegisterableClass<Component> {
} // namespace dragnn } // namespace dragnn
} // namespace syntaxnet } // namespace syntaxnet
#endif // NLP_SAFT_OPENSOURCE_DRAGNN_CORE_INTERFACES_COMPONENT_H_ #endif // DRAGNN_CORE_INTERFACES_COMPONENT_H_
...@@ -13,8 +13,8 @@ ...@@ -13,8 +13,8 @@
// limitations under the License. // limitations under the License.
// ============================================================================= // =============================================================================
#ifndef NLP_SAFT_OPENSOURCE_DRAGNN_CORE_INTERFACES_INPUT_BATCH_H_ #ifndef DRAGNN_CORE_INTERFACES_INPUT_BATCH_H_
#define NLP_SAFT_OPENSOURCE_DRAGNN_CORE_INTERFACES_INPUT_BATCH_H_ #define DRAGNN_CORE_INTERFACES_INPUT_BATCH_H_
#include <string> #include <string>
#include <vector> #include <vector>
...@@ -32,14 +32,17 @@ class InputBatch { ...@@ -32,14 +32,17 @@ class InputBatch {
public: public:
virtual ~InputBatch() {} virtual ~InputBatch() {}
// Set the data to translate to the subclass' data type. // Sets the data to translate to the subclass' data type. Call at most once.
virtual void SetData(const std::vector<string> &data) = 0; virtual void SetData(const std::vector<string> &data) = 0;
// Translate the underlying data back to a vector of strings, as appropriate. // Returns the size of the batch.
virtual int GetSize() const = 0;
// Translates the underlying data back to a vector of strings, as appropriate.
virtual const std::vector<string> GetSerializedData() const = 0; virtual const std::vector<string> GetSerializedData() const = 0;
}; };
} // namespace dragnn } // namespace dragnn
} // namespace syntaxnet } // namespace syntaxnet
#endif // NLP_SAFT_OPENSOURCE_DRAGNN_CORE_INTERFACES_INPUT_BATCH_H_ #endif // DRAGNN_CORE_INTERFACES_INPUT_BATCH_H_
...@@ -13,8 +13,8 @@ ...@@ -13,8 +13,8 @@
// limitations under the License. // limitations under the License.
// ============================================================================= // =============================================================================
#ifndef NLP_SAFT_OPENSOURCE_DRAGNN_CORE_INTERFACES_TRANSITION_STATE_H_ #ifndef DRAGNN_CORE_INTERFACES_TRANSITION_STATE_H_
#define NLP_SAFT_OPENSOURCE_DRAGNN_CORE_INTERFACES_TRANSITION_STATE_H_ #define DRAGNN_CORE_INTERFACES_TRANSITION_STATE_H_
#include <memory> #include <memory>
#include <vector> #include <vector>
...@@ -44,19 +44,25 @@ class TransitionState { ...@@ -44,19 +44,25 @@ class TransitionState {
// Return the beam index of the state passed into the initializer of this // Return the beam index of the state passed into the initializer of this
// TransitionState. // TransitionState.
virtual const int ParentBeamIndex() const = 0; virtual int ParentBeamIndex() const = 0;
// Get the current beam index for this state. // Gets the current beam index for this state.
virtual const int GetBeamIndex() const = 0; virtual int GetBeamIndex() const = 0;
// Set the current beam index for this state. // Sets the current beam index for this state.
virtual void SetBeamIndex(const int index) = 0; virtual void SetBeamIndex(int index) = 0;
// Get the score associated with this transition state. // Gets the score associated with this transition state.
virtual const float GetScore() const = 0; virtual float GetScore() const = 0;
// Set the score associated with this transition state. // Sets the score associated with this transition state.
virtual void SetScore(const float score) = 0; virtual void SetScore(float score) = 0;
// Gets the gold-ness of this state (whether it is on the oracle path)
virtual bool IsGold() const = 0;
// Sets the gold-ness of this state.
virtual void SetGold(bool is_gold) = 0;
// Depicts this state as an HTML-language string. // Depicts this state as an HTML-language string.
virtual string HTMLRepresentation() const = 0; virtual string HTMLRepresentation() const = 0;
...@@ -65,4 +71,4 @@ class TransitionState { ...@@ -65,4 +71,4 @@ class TransitionState {
} // namespace dragnn } // namespace dragnn
} // namespace syntaxnet } // namespace syntaxnet
#endif // NLP_SAFT_OPENSOURCE_DRAGNN_CORE_INTERFACES_TRANSITION_STATE_H_ #endif // DRAGNN_CORE_INTERFACES_TRANSITION_STATE_H_
...@@ -13,8 +13,8 @@ ...@@ -13,8 +13,8 @@
// limitations under the License. // limitations under the License.
// ============================================================================= // =============================================================================
#ifndef NLP_SAFT_OPENSOURCE_DRAGNN_CORE_OPS_COMPUTE_SESSION_OP_H_ #ifndef DRAGNN_CORE_OPS_COMPUTE_SESSION_OP_H_
#define NLP_SAFT_OPENSOURCE_DRAGNN_CORE_OPS_COMPUTE_SESSION_OP_H_ #define DRAGNN_CORE_OPS_COMPUTE_SESSION_OP_H_
#include <string> #include <string>
...@@ -66,4 +66,4 @@ class ComputeSessionOp : public tensorflow::OpKernel { ...@@ -66,4 +66,4 @@ class ComputeSessionOp : public tensorflow::OpKernel {
} // namespace dragnn } // namespace dragnn
} // namespace syntaxnet } // namespace syntaxnet
#endif // NLP_SAFT_OPENSOURCE_DRAGNN_CORE_OPS_COMPUTE_SESSION_OP_H_ #endif // DRAGNN_CORE_OPS_COMPUTE_SESSION_OP_H_
...@@ -303,6 +303,73 @@ class BulkFixedEmbeddings : public ComputeSessionOp { ...@@ -303,6 +303,73 @@ class BulkFixedEmbeddings : public ComputeSessionOp {
REGISTER_KERNEL_BUILDER(Name("BulkFixedEmbeddings").Device(DEVICE_CPU), REGISTER_KERNEL_BUILDER(Name("BulkFixedEmbeddings").Device(DEVICE_CPU),
BulkFixedEmbeddings); BulkFixedEmbeddings);
// See docstring in dragnn_bulk_ops.cc.
//
// Op kernel that embeds all fixed features for every step of a component in
// one bulk call, writing the concatenated embedding vectors into a single
// padded output tensor. Inputs: [state handle, one embedding matrix per
// channel]. Outputs: [state handle, embedding vectors, num_steps].
class BulkEmbedFixedFeatures : public ComputeSessionOp {
 public:
  explicit BulkEmbedFixedFeatures(OpKernelConstruction *context)
      : ComputeSessionOp(context) {
    OP_REQUIRES_OK(context, context->GetAttr("num_channels", &num_channels_));
    // The input vector's zeroth element is the state handle, and the remaining
    // num_channels_ elements are tensors of float embeddings, one per channel.
    vector<DataType> input_types(num_channels_ + 1, DT_FLOAT);
    input_types[0] = DT_STRING;
    const vector<DataType> output_types = {DT_STRING, DT_FLOAT, DT_INT32};
    OP_REQUIRES_OK(context, context->MatchSignature(input_types, output_types));
    OP_REQUIRES_OK(context, context->GetAttr("pad_to_batch", &pad_to_batch_));
    OP_REQUIRES_OK(context, context->GetAttr("pad_to_steps", &pad_to_steps_));
  }

  // This op passes the ComputeSession handle through as output 0.
  bool OutputsHandle() const override { return true; }

  // This op is bound to a specific component via the "component" attr.
  bool RequiresComponentName() const override { return true; }

  void ComputeWithState(OpKernelContext *context,
                        ComputeSession *session) override {
    const auto &spec = session->Spec(component_name());

    // Each channel contributes (embedding_dim * feature_count) floats to the
    // concatenated per-item embedding vector; sum them to size the output.
    int embedding_size = 0;
    std::vector<const float *> embeddings(num_channels_);
    for (int channel = 0; channel < num_channels_; ++channel) {
      // Input 0 is the handle, so channel embeddings start at input 1.
      const int embeddings_index = channel + 1;
      embedding_size += context->input(embeddings_index).shape().dim_size(1) *
                        spec.fixed_feature(channel).size();
      embeddings[channel] =
          context->input(embeddings_index).flat<float>().data();
    }

    // Output 1 holds one row per (step x batch x beam) slot, padded to
    // pad_to_steps_ steps and pad_to_batch_ batch elements.
    Tensor *embedding_vectors;
    OP_REQUIRES_OK(context,
                   context->allocate_output(
                       1,
                       TensorShape({pad_to_steps_ * pad_to_batch_ *
                                        session->BeamSize(component_name()),
                                    embedding_size}),
                       &embedding_vectors));

    // Output 2 is a scalar reporting the (padded) number of steps.
    Tensor *num_steps_tensor;
    OP_REQUIRES_OK(context, context->allocate_output(2, TensorShape({}),
                                                     &num_steps_tensor));

    // Zero the output first so padding slots that the session never writes
    // remain zero-valued.
    embedding_vectors->flat<float>().setZero();
    int output_size = embedding_vectors->NumElements();
    session->BulkEmbedFixedFeatures(component_name(), pad_to_batch_,
                                    pad_to_steps_, output_size, embeddings,
                                    embedding_vectors->flat<float>().data());
    num_steps_tensor->scalar<int32>()() = pad_to_steps_;
  }

 private:
  // Number of fixed feature channels.
  int num_channels_;

  // Will pad output to this many batch elements.
  int pad_to_batch_;

  // Will pad output to this many steps.
  int pad_to_steps_;

  TF_DISALLOW_COPY_AND_ASSIGN(BulkEmbedFixedFeatures);
};

REGISTER_KERNEL_BUILDER(Name("BulkEmbedFixedFeatures").Device(DEVICE_CPU),
                        BulkEmbedFixedFeatures);
// See docstring in dragnn_bulk_ops.cc. // See docstring in dragnn_bulk_ops.cc.
class BulkAdvanceFromOracle : public ComputeSessionOp { class BulkAdvanceFromOracle : public ComputeSessionOp {
public: public:
...@@ -387,8 +454,11 @@ class BulkAdvanceFromPrediction : public ComputeSessionOp { ...@@ -387,8 +454,11 @@ class BulkAdvanceFromPrediction : public ComputeSessionOp {
} }
} }
if (!session->IsTerminal(component_name())) { if (!session->IsTerminal(component_name())) {
session->AdvanceFromPrediction(component_name(), scores_per_step.data(), bool success = session->AdvanceFromPrediction(
scores_per_step.size()); component_name(), scores_per_step.data(), num_items, num_actions);
OP_REQUIRES(
context, success,
tensorflow::errors::Internal("Unable to advance from prediction."));
} }
} }
} }
......
...@@ -375,6 +375,114 @@ TEST_F(DragnnBulkOpKernelsTest, BulkFixedEmbeddings) { ...@@ -375,6 +375,114 @@ TEST_F(DragnnBulkOpKernelsTest, BulkFixedEmbeddings) {
EXPECT_EQ(kNumSteps, GetOutput(2)->scalar<int32>()()); EXPECT_EQ(kNumSteps, GetOutput(2)->scalar<int32>()());
} }
// Verifies that BulkEmbedFixedFeatures forwards the per-channel embedding
// matrices and padding attributes to the ComputeSession, allocates an output
// tensor large enough for (steps x batch x beam) padded rows, and surfaces
// whatever the session writes into that output.
TEST_F(DragnnBulkOpKernelsTest, BulkEmbedFixedFeatures) {
  // Create and initialize the kernel under test.
  constexpr int kBatchPad = 7;
  constexpr int kStepPad = 5;
  constexpr int kMaxBeamSize = 3;
  TF_ASSERT_OK(
      NodeDefBuilder("BulkEmbedFixedFeatures", "BulkEmbedFixedFeatures")
          .Attr("component", kComponentName)
          .Attr("num_channels", kNumChannels)
          .Attr("pad_to_batch", kBatchPad)
          .Attr("pad_to_steps", kStepPad)
          .Input(FakeInput(DT_STRING))  // The handle for the ComputeSession.
          .Input(FakeInput(DT_FLOAT))   // Embedding matrices.
          .Finalize(node_def()));
  MockComputeSession *mock_session = GetMockSession();

  // Build a component spec with two fixed-feature channels; the kernel reads
  // each channel's feature count to compute the concatenated embedding size.
  ComponentSpec spec;
  spec.set_name(kComponentName);
  auto chan0_spec = spec.add_fixed_feature();
  constexpr int kChan0FeatureCount = 2;
  chan0_spec->set_size(kChan0FeatureCount);
  auto chan1_spec = spec.add_fixed_feature();
  constexpr int kChan1FeatureCount = 1;
  chan1_spec->set_size(kChan1FeatureCount);
  EXPECT_CALL(*mock_session, Spec(kComponentName))
      .WillOnce(testing::ReturnRef(spec));
  EXPECT_CALL(*mock_session, BeamSize(kComponentName))
      .WillOnce(testing::Return(kMaxBeamSize));

  // Embedding matrices as additional inputs.
  // For channel 0, the embeddings are [id, id, id].
  // For channel 1, the embeddings are [id^2, id^2, id^2, ... ,id^2].
  vector<float> embedding_matrix_0;
  constexpr int kEmbedding0Size = 3;
  vector<float> embedding_matrix_1;
  constexpr int kEmbedding1Size = 9;
  for (int id = 0; id < kNumIds; ++id) {
    for (int i = 0; i < kEmbedding0Size; ++i) {
      embedding_matrix_0.push_back(id);
      LOG(INFO) << embedding_matrix_0.back();
    }
    for (int i = 0; i < kEmbedding1Size; ++i) {
      embedding_matrix_1.push_back(id * id);
      // Fixed copy-paste bug: log the value just appended to matrix 1, not
      // the stale back() of matrix 0.
      LOG(INFO) << embedding_matrix_1.back();
    }
  }
  AddInputFromArray<float>(TensorShape({kNumIds, kEmbedding0Size}),
                           embedding_matrix_0);
  AddInputFromArray<float>(TensorShape({kNumIds, kEmbedding1Size}),
                           embedding_matrix_1);

  // Expected concatenated embedding width and total padded output size.
  constexpr int kExpectedEmbeddingSize = kChan0FeatureCount * kEmbedding0Size +
                                         kChan1FeatureCount * kEmbedding1Size;
  constexpr int kExpectedOutputSize =
      kExpectedEmbeddingSize * kBatchPad * kStepPad * kMaxBeamSize;

  // This function takes the allocator functions passed into GetBulkFF, uses
  // them to allocate a tensor, then fills that tensor based on channel.
  auto eval_function = [=](const string &component_name, int batch_size_padding,
                           int num_steps_padding, int output_array_size,
                           const vector<const float *> &per_channel_embeddings,
                           float *embedding_output) {
    // Validate the control variables.
    EXPECT_EQ(batch_size_padding, kBatchPad);
    EXPECT_EQ(num_steps_padding, kStepPad);
    EXPECT_EQ(output_array_size, kExpectedOutputSize);

    // Validate the passed embeddings.
    for (int i = 0; i < kNumIds; ++i) {
      for (int j = 0; j < kEmbedding0Size; ++j) {
        float ch0_embedding =
            per_channel_embeddings.at(0)[i * kEmbedding0Size + j];
        EXPECT_FLOAT_EQ(ch0_embedding, i)
            << "Failed match at " << i << "," << j;
      }
      for (int j = 0; j < kEmbedding1Size; ++j) {
        float ch1_embedding =
            per_channel_embeddings.at(1)[i * kEmbedding1Size + j];
        EXPECT_FLOAT_EQ(ch1_embedding, i * i)
            << "Failed match at " << i << "," << j;
      }
    }

    // Fill the output matrix to the expected size. This will trigger msan
    // if the allocation wasn't big enough.
    for (int i = 0; i < kExpectedOutputSize; ++i) {
      embedding_output[i] = i;
    }
  };
  EXPECT_CALL(*mock_session,
              BulkEmbedFixedFeatures(kComponentName, _, _, _, _, _))
      .WillOnce(testing::Invoke(eval_function));

  // Run the kernel.
  TF_EXPECT_OK(RunOpKernelWithContext());

  // Validate outputs: shape, the values written by eval_function, and the
  // reported (padded) step count.
  EXPECT_EQ(kBatchPad * kStepPad * kMaxBeamSize,
            GetOutput(1)->shape().dim_size(0));
  EXPECT_EQ(kExpectedEmbeddingSize, GetOutput(1)->shape().dim_size(1));
  auto output_data = GetOutput(1)->flat<float>();
  for (int i = 0; i < kExpectedOutputSize; ++i) {
    EXPECT_FLOAT_EQ(i, output_data(i));
  }
  EXPECT_EQ(kStepPad, GetOutput(2)->scalar<int32>()());
}
TEST_F(DragnnBulkOpKernelsTest, BulkFixedEmbeddingsWithPadding) { TEST_F(DragnnBulkOpKernelsTest, BulkFixedEmbeddingsWithPadding) {
// Create and initialize the kernel under test. // Create and initialize the kernel under test.
constexpr int kPaddedNumSteps = 5; constexpr int kPaddedNumSteps = 5;
...@@ -592,12 +700,54 @@ TEST_F(DragnnBulkOpKernelsTest, BulkAdvanceFromPrediction) { ...@@ -592,12 +700,54 @@ TEST_F(DragnnBulkOpKernelsTest, BulkAdvanceFromPrediction) {
EXPECT_CALL(*mock_session, EXPECT_CALL(*mock_session,
AdvanceFromPrediction(kComponentName, AdvanceFromPrediction(kComponentName,
CheckScoresAreConsecutiveIntegersDivTen(), CheckScoresAreConsecutiveIntegersDivTen(),
kNumItems * kNumActions)) kNumItems, kNumActions))
.Times(kNumSteps); .Times(kNumSteps)
.WillRepeatedly(Return(true));
// Run the kernel. // Run the kernel.
TF_EXPECT_OK(RunOpKernelWithContext()); TF_EXPECT_OK(RunOpKernelWithContext());
} }
// Verifies that BulkAdvanceFromPrediction reports a non-OK status when the
// ComputeSession's AdvanceFromPrediction call returns false mid-run.
TEST_F(DragnnBulkOpKernelsTest, BulkAdvanceFromPredictionFailsIfAdvanceFails) {
  // Create and initialize the kernel under test.
  TF_ASSERT_OK(
      NodeDefBuilder("BulkAdvanceFromPrediction", "BulkAdvanceFromPrediction")
          .Attr("component", kComponentName)
          .Input(FakeInput(DT_STRING))  // The handle for the ComputeSession.
          .Input(FakeInput(DT_FLOAT))   // Prediction scores for advancing.
          .Finalize(node_def()));
  MockComputeSession *mock_session = GetMockSession();

  // Creates an input tensor such that each step will see a list of consecutive
  // integers divided by 10 as scores.
  vector<float> scores(kNumItems * kNumSteps * kNumActions);
  int counter = 0;
  for (int step = 0; step < kNumSteps; ++step) {
    for (int item = 0; item < kNumItems; ++item) {
      // Rows are laid out item-major, then step; actions are contiguous.
      const int row = item * kNumSteps + step;
      for (int action = 0; action < kNumActions; ++action) {
        scores[row * kNumActions + action] = counter / 10.0f;
        ++counter;
      }
    }
  }
  AddInputFromArray<float>(TensorShape({kNumItems * kNumSteps, kNumActions}),
                           scores);

  EXPECT_CALL(*mock_session, BeamSize(kComponentName)).WillOnce(Return(1));
  EXPECT_CALL(*mock_session, BatchSize(kComponentName))
      .WillOnce(Return(kNumItems));
  EXPECT_CALL(*mock_session, IsTerminal(kComponentName))
      .Times(2)
      .WillRepeatedly(Return(false));

  // The first advance succeeds; the second fails, which should abort the op.
  EXPECT_CALL(*mock_session,
              AdvanceFromPrediction(kComponentName,
                                    CheckScoresAreConsecutiveIntegersDivTen(),
                                    kNumItems, kNumActions))
      .WillOnce(Return(true))
      .WillOnce(Return(false));

  // Run the kernel and expect a failure status.
  auto result = RunOpKernelWithContext();
  EXPECT_FALSE(result.ok());
}
} // namespace dragnn } // namespace dragnn
} // namespace syntaxnet } // namespace syntaxnet
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment