Commit 4364390a authored by Ivan Bogatyy's avatar Ivan Bogatyy Committed by calberti
Browse files

Release DRAGNN bulk networks (#2785)

* Release DRAGNN bulk networks
parent 638fd759
......@@ -33,8 +33,9 @@ cc_library(
name = "compute_session",
hdrs = ["compute_session.h"],
deps = [
":index_translator",
":input_batch_cache",
"//dragnn/components/util:bulk_feature_extractor",
"//dragnn/core:index_translator",
"//dragnn/core/interfaces:component",
"//dragnn/protos:spec_proto",
"//dragnn/protos:trace_proto",
......@@ -120,8 +121,10 @@ cc_test(
":compute_session",
":compute_session_impl",
":compute_session_pool",
":input_batch_cache",
"//dragnn/components/util:bulk_feature_extractor",
"//dragnn/core/interfaces:component",
"//dragnn/core/interfaces:input_batch",
"//dragnn/core/test:generic",
"//dragnn/core/test:mock_component",
"//dragnn/core/test:mock_transition_state",
......@@ -248,6 +251,7 @@ cc_library(
"//syntaxnet:base",
"@org_tensorflow//third_party/eigen3",
],
alwayslink = 1,
)
# Tensorflow kernel libraries, for use with unit tests.
......
......@@ -13,8 +13,8 @@
// limitations under the License.
// =============================================================================
#ifndef NLP_SAFT_OPENSOURCE_DRAGNN_CORE_BEAM_H_
#define NLP_SAFT_OPENSOURCE_DRAGNN_CORE_BEAM_H_
#ifndef DRAGNN_CORE_BEAM_H_
#define DRAGNN_CORE_BEAM_H_
#include <algorithm>
#include <cmath>
......@@ -43,19 +43,23 @@ class Beam {
static_assert(
std::is_base_of<CloneableTransitionState<T>, T>::value,
"This class must be instantiated to use a CloneableTransitionState");
track_gold_ = false;
}
// Sets whether or not the beam should track gold states.
void SetGoldTracking(bool track_gold) { track_gold_ = track_gold; }
// Sets the Beam functions, as follows:
// bool is_allowed(TransitionState *, int): Return true if transition 'int' is
// allowed for transition state 'TransitionState *'.
// void perform_transition(TransitionState *, int): Performs transition 'int'
// on transition state 'TransitionState *'.
// int oracle_function(TransitionState *): Returns the oracle-specified action
// for transition state 'TransitionState *'.
// vector<int> oracle_function(TransitionState *): Returns the oracle-
// specified actions for transition state 'TransitionState *'.
void SetFunctions(std::function<bool(T *, int)> is_allowed,
std::function<bool(T *)> is_final,
std::function<void(T *, int)> perform_transition,
std::function<int(T *)> oracle_function) {
std::function<vector<int>(T *)> oracle_function) {
is_allowed_ = is_allowed;
is_final_ = is_final;
perform_transition_ = perform_transition;
......@@ -74,12 +78,17 @@ class Beam {
for (int i = 0; i < beam_.size(); ++i) {
previous_beam_indices.at(i) = beam_[i]->ParentBeamIndex();
beam_[i]->SetBeamIndex(i);
// TODO(googleuser): Add gold tracking to component-level state creation.
if (!track_gold_) {
beam_[i]->SetGold(false);
}
}
beam_index_history_.emplace_back(previous_beam_indices);
}
// Advances the Beam from the given transition matrix.
void AdvanceFromPrediction(const float transition_matrix[], int matrix_length,
bool AdvanceFromPrediction(const float *transition_matrix, int matrix_length,
int num_actions) {
// Ensure that the transition matrix is the correct size. All underlying
// states should have the same transition profile, so using the one at 0
......@@ -89,91 +98,20 @@ class Beam {
"state transitions!";
if (max_size_ == 1) {
// In the case where beam size is 1, we can advance by simply finding the
// highest score and advancing the beam state in place.
VLOG(2) << "Beam size is 1. Using fast beam path.";
int best_action = -1;
float best_score = -INFINITY;
auto &state = beam_[0];
for (int action_idx = 0; action_idx < num_actions; ++action_idx) {
if (is_allowed_(state.get(), action_idx) &&
transition_matrix[action_idx] > best_score) {
best_score = transition_matrix[action_idx];
best_action = action_idx;
}
bool success = FastAdvanceFromPrediction(transition_matrix, num_actions);
if (!success) {
return false;
}
CHECK_GE(best_action, 0) << "Num actions: " << num_actions
<< " score[0]: " << transition_matrix[0];
perform_transition_(state.get(), best_action);
const float new_score = state->GetScore() + best_score;
state->SetScore(new_score);
state->SetBeamIndex(0);
} else {
// Create the vector of all possible transitions, along with their scores.
std::vector<Transition> candidates;
// Iterate through all beams, examining all actions for each beam.
for (int beam_idx = 0; beam_idx < beam_.size(); ++beam_idx) {
const auto &state = beam_[beam_idx];
for (int action_idx = 0; action_idx < num_actions; ++action_idx) {
// If the action is allowed, calculate the proposed new score and add
// the candidate action to the vector of all actions at this state.
if (is_allowed_(state.get(), action_idx)) {
Transition candidate;
// The matrix is laid out by beam index, with a linear set of
// actions for that index - so beam N's actions start at [nr. of
// actions]*[N].
const int matrix_idx = action_idx + beam_idx * num_actions;
CHECK_LT(matrix_idx, matrix_length)
<< "Matrix index out of bounds!";
const double score_delta = transition_matrix[matrix_idx];
CHECK(!std::isnan(score_delta));
candidate.source_idx = beam_idx;
candidate.action = action_idx;
candidate.resulting_score = state->GetScore() + score_delta;
candidates.emplace_back(candidate);
}
}
}
// Sort the vector of all possible transitions and scores.
const auto comparator = [](const Transition &a, const Transition &b) {
return a.resulting_score > b.resulting_score;
};
std::stable_sort(candidates.begin(), candidates.end(), comparator);
// Apply the top transitions, up to a maximum of 'max_size_'.
std::vector<std::unique_ptr<T>> new_beam;
std::vector<int> previous_beam_indices(max_size_, -1);
const int beam_size =
std::min(max_size_, static_cast<int>(candidates.size()));
VLOG(2) << "Previous beam size = " << beam_.size();
VLOG(2) << "New beam size = " << beam_size;
VLOG(2) << "Maximum beam size = " << max_size_;
for (int i = 0; i < beam_size; ++i) {
// Get the source of the i'th transition.
const auto &transition = candidates[i];
VLOG(2) << "Taking transition with score: "
<< transition.resulting_score
<< " and action: " << transition.action;
VLOG(2) << "transition.source_idx = " << transition.source_idx;
const auto &source = beam_[transition.source_idx];
// Put the new transition on the new state beam.
auto new_state = source->Clone();
perform_transition_(new_state.get(), transition.action);
new_state->SetScore(transition.resulting_score);
new_state->SetBeamIndex(i);
previous_beam_indices.at(i) = transition.source_idx;
new_beam.emplace_back(std::move(new_state));
bool success = BeamAdvanceFromPrediction(transition_matrix, matrix_length,
num_actions);
if (!success) {
return false;
}
beam_ = std::move(new_beam);
beam_index_history_.emplace_back(previous_beam_indices);
}
++num_steps_;
return true;
}
// Advances the Beam from the state oracles.
......@@ -182,7 +120,10 @@ class Beam {
for (int i = 0; i < beam_.size(); ++i) {
previous_beam_indices.at(i) = i;
if (is_final_(beam_[i].get())) continue;
const auto oracle_label = oracle_function_(beam_[i].get());
// There will always be at least one oracular transition, and taking the
// first returned transition is never worse than any other option.
const int oracle_label = oracle_function_(beam_[i].get()).at(0);
VLOG(2) << "AdvanceFromOracle beam_index:" << i
<< " oracle_label:" << oracle_label;
perform_transition_(beam_[i].get(), oracle_label);
......@@ -312,19 +253,180 @@ class Beam {
// Returns the current size of the beam.
const int size() const { return beam_.size(); }
// Returns true if at least one of the states in the beam is gold.
bool ContainsGold() {
if (!track_gold_) {
return false;
}
for (const auto &state : beam_) {
if (state->IsGold()) {
return true;
}
}
return false;
}
private:
// Associates an action taken on an index into current_state_ with a score.
friend void BM_FastAdvance(int num_iters, int num_transitions);
friend void BM_BeamAdvance(int num_iters, int num_transitions,
int max_beam_size);
// Associates an action taken with its source index. Used to describe one
// candidate expansion of a beam state during beam search.
struct Transition {
  // The index (into the current beam) of the source item.
  int source_idx;
  // The index of the action being taken.
  int action;
  // The score of the full derivation. NOTE(review): callers may aggregate-
  // initialize only {source_idx, action}, leaving this value-initialized to
  // 0 and carrying the score separately — do not assume it is always set.
  double resulting_score;
};
// In the case where beam size is 1, we can advance by simply finding the
// highest-scoring allowed action and advancing the single beam state in
// place. Returns false (leaving the state untouched) if the transition
// matrix contains a NaN or if no action is allowed; returns true otherwise.
bool FastAdvanceFromPrediction(const float *transition_matrix,
                               int num_actions) {
  CHECK_EQ(1, max_size_)
      << "Using fast advance on invalid beam. This should never happen.";
  VLOG(2) << "Beam size is 1. Using fast beam path.";
  constexpr int kNoActionChosen = -1;
  int best_action = kNoActionChosen;
  float best_score = -INFINITY;
  auto &state = beam_[0];
  for (int action_idx = 0; action_idx < num_actions; ++action_idx) {
    // A NaN score is unrecoverable: report failure rather than letting the
    // comparison below silently skip or select an arbitrary action.
    if (std::isnan(transition_matrix[action_idx])) {
      LOG(ERROR) << "Found a NaN in the transition matrix! Unable to "
                    "continue. Num actions: "
                 << num_actions << " index: " << action_idx;
      return false;
    }
    if (is_allowed_(state.get(), action_idx) &&
        transition_matrix[action_idx] > best_score) {
      best_score = transition_matrix[action_idx];
      best_action = action_idx;
    }
  }
  if (best_action == kNoActionChosen) {
    LOG(ERROR) << "No action was chosen! Unable to continue. Num actions: "
               << num_actions << " score[0]: " << transition_matrix[0];
    return false;
  }
  // The state remains gold only if it was gold before this step AND the
  // chosen action is one of the oracle's transitions for it.
  bool is_gold = false;
  if (track_gold_ && state->IsGold()) {
    for (const auto &gold_transition : oracle_function_(state.get())) {
      // The fast path always operates on beam index 0.
      VLOG(3) << "Examining gold transition " << gold_transition
              << " for source index 0";
      if (gold_transition == best_action) {
        is_gold = true;
        break;
      }
    }
  }
  perform_transition_(state.get(), best_action);
  const float new_score = state->GetScore() + best_score;
  state->SetScore(new_score);
  state->SetBeamIndex(0);
  state->SetGold(is_gold);
  return true;
}
// In case the beam size is greater than 1, we need to advance using standard
// beam search: expand every allowed (state, action) pair, keep the top
// 'max_size_' expansions by total score, and rebuild the beam from them.
// Returns false (without modifying the beam) if any resulting score is NaN;
// returns true otherwise.
bool BeamAdvanceFromPrediction(const float *transition_matrix,
                               int matrix_length, int num_actions) {
  VLOG(2) << "Beam size is " << max_size_ << ". Using standard beam search.";
  // Keep the multimap sorted high to low. The sort order for
  // identical keys is stable.
  std::multimap<float, Transition, std::greater<float>> candidates;
  // Lowest score currently held in 'candidates'; only consulted once the
  // candidate set is full (size == max_size_).
  float threshold = -INFINITY;
  // Iterate through all beams, examining all actions for each beam.
  for (int beam_idx = 0; beam_idx < beam_.size(); ++beam_idx) {
    const auto &state = beam_[beam_idx];
    const float score = state->GetScore();
    for (int action_idx = 0; action_idx < num_actions; ++action_idx) {
      if (is_allowed_(state.get(), action_idx)) {
        // The matrix is laid out by beam index, with a linear set of
        // actions for that index - so beam N's actions start at [nr. of
        // actions]*[N].
        const int matrix_idx = action_idx + beam_idx * num_actions;
        CHECK_LT(matrix_idx, matrix_length) << "Matrix index out of bounds!";
        const float resulting_score = score + transition_matrix[matrix_idx];
        if (std::isnan(resulting_score)) {
          LOG(ERROR) << "Resulting score was a NaN! Unable to continue. Num "
                        "actions: "
                     << num_actions << " action index " << action_idx;
          return false;
        }
        if (candidates.size() == max_size_) {
          // If the new score is lower than the bottom of the beam, move on.
          if (resulting_score < threshold) {
            continue;
          }
          // Otherwise, remove the bottom of the beam, making space
          // for the new candidate.
          candidates.erase(std::prev(candidates.end()));
        }
        // Add the new candidate, and update the threshold score. Note that
        // the candidate's 'resulting_score' field is not set here; the score
        // is carried as the multimap key instead.
        const Transition candidate{beam_idx, action_idx};
        candidates.emplace(resulting_score, candidate);
        threshold = candidates.rbegin()->first;
      }
    }
  }
  // Apply the top transitions, up to a maximum of 'max_size_'.
  std::vector<std::unique_ptr<T>> new_beam;
  std::vector<int> previous_beam_indices(max_size_, -1);
  const int beam_size = candidates.size();
  new_beam.reserve(max_size_);
  VLOG(2) << "Previous beam size = " << beam_.size();
  VLOG(2) << "New beam size = " << beam_size;
  VLOG(2) << "Maximum beam size = " << max_size_;
  // 'candidates' iterates highest score first, so new beam slot i receives
  // the i'th best expansion.
  auto candidate_iterator = candidates.cbegin();
  for (int i = 0; i < beam_size; ++i) {
    // Get the score and source of the i'th transition.
    const float resulting_score = candidate_iterator->first;
    const auto &transition = candidate_iterator->second;
    ++candidate_iterator;
    VLOG(2) << "Taking transition with score: " << resulting_score
            << " and action: " << transition.action;
    VLOG(2) << "transition.source_idx = " << transition.source_idx;
    const auto &source = beam_[transition.source_idx];
    // Determine if the transition being taken will result in a gold state:
    // the source must be gold and the action must match an oracle action.
    bool is_gold = false;
    if (track_gold_ && source->IsGold()) {
      for (const auto &gold_transition : oracle_function_(source.get())) {
        VLOG(3) << "Examining gold transition " << gold_transition
                << " for source index " << transition.source_idx;
        if (gold_transition == transition.action) {
          VLOG(2) << "State from index " << transition.source_idx
                  << " is gold.";
          is_gold = true;
          break;
        }
      }
    }
    VLOG(2) << "Gold examination complete for source index "
            << transition.source_idx;
    // Put the new transition on the new state beam. The source state is
    // cloned (not moved) because several candidates may share a source.
    auto new_state = source->Clone();
    perform_transition_(new_state.get(), transition.action);
    new_state->SetScore(resulting_score);
    new_state->SetBeamIndex(i);
    new_state->SetGold(is_gold);
    previous_beam_indices.at(i) = transition.source_idx;
    new_beam.emplace_back(std::move(new_state));
  }
  beam_ = std::move(new_beam);
  beam_index_history_.emplace_back(previous_beam_indices);
  return true;
}
// The maximum beam size.
int max_size_;
......@@ -341,7 +443,7 @@ class Beam {
std::function<void(T *, int)> perform_transition_;
// Function to provide the oracle action for a given state.
std::function<int(T *)> oracle_function_;
std::function<vector<int>(T *)> oracle_function_;
// The history of the states in this beam. The vector indexes across steps.
// For every step, there is a vector in the vector. This inner vector denotes
......@@ -355,9 +457,12 @@ class Beam {
// The number of steps taken so far.
int num_steps_;
// Whether to track golden states.
bool track_gold_;
};
} // namespace dragnn
} // namespace syntaxnet
#endif // NLP_SAFT_OPENSOURCE_DRAGNN_CORE_BEAM_H_
#endif // DRAGNN_CORE_BEAM_H_
This diff is collapsed.
......@@ -13,12 +13,17 @@
// limitations under the License.
// =============================================================================
#ifndef NLP_SAFT_OPENSOURCE_DRAGNN_CORE_COMPONENT_REGISTRY_H_
#define NLP_SAFT_OPENSOURCE_DRAGNN_CORE_COMPONENT_REGISTRY_H_
#ifndef DRAGNN_CORE_COMPONENT_REGISTRY_H_
#define DRAGNN_CORE_COMPONENT_REGISTRY_H_
#include "dragnn/core/interfaces/component.h"
#include "syntaxnet/registry.h"
namespace syntaxnet {
// Class registry for DRAGNN components.
DECLARE_SYNTAXNET_CLASS_REGISTRY("DRAGNN Component", dragnn::Component);
} // namespace syntaxnet
// Macro to add a component to the registry. This macro associates a class with
// its class name as a string, so FooComponent would be associated with the
// string "FooComponent".
......@@ -26,4 +31,4 @@
REGISTER_SYNTAXNET_CLASS_COMPONENT(syntaxnet::dragnn::Component, #component, \
component)
#endif // NLP_SAFT_OPENSOURCE_DRAGNN_CORE_COMPONENT_REGISTRY_H_
#endif // DRAGNN_CORE_COMPONENT_REGISTRY_H_
......@@ -13,13 +13,14 @@
// limitations under the License.
// =============================================================================
#ifndef NLP_SAFT_OPENSOURCE_DRAGNN_CORE_COMPUTE_SESSION_H_
#define NLP_SAFT_OPENSOURCE_DRAGNN_CORE_COMPUTE_SESSION_H_
#ifndef DRAGNN_CORE_COMPUTE_SESSION_H_
#define DRAGNN_CORE_COMPUTE_SESSION_H_
#include <string>
#include "dragnn/components/util/bulk_feature_extractor.h"
#include "dragnn/core/index_translator.h"
#include "dragnn/core/input_batch_cache.h"
#include "dragnn/core/interfaces/component.h"
#include "dragnn/protos/spec.pb.h"
#include "dragnn/protos/trace.pb.h"
......@@ -64,10 +65,11 @@ class ComputeSession {
// Advance the given component using the component's oracle.
virtual void AdvanceFromOracle(const string &component_name) = 0;
// Advance the given component using the given score matrix.
virtual void AdvanceFromPrediction(const string &component_name,
const float score_matrix[],
int score_matrix_length) = 0;
// Advance the given component using the given score matrix, which is
// |num_items| x |num_actions|.
virtual bool AdvanceFromPrediction(const string &component_name,
const float *score_matrix, int num_items,
int num_actions) = 0;
// Get the input features for the given component and channel. This passes
// through to the relevant Component's GetFixedFeatures() call.
......@@ -84,6 +86,15 @@ class ComputeSession {
virtual int BulkGetInputFeatures(const string &component_name,
const BulkFeatureExtractor &extractor) = 0;
// Directly computes the embedding matrix for all channels, advancing the
// component via the oracle until it is terminal. This call takes a vector
// of float embedding matrices, one per channel, in channel order.
virtual void BulkEmbedFixedFeatures(
const string &component_name, int batch_size_padding,
int num_steps_padding, int output_array_size,
const vector<const float *> &per_channel_embeddings,
float *embedding_output) = 0;
// Get the input features for the given component and channel. This function
// can return empty LinkFeatures protos, which represent unused padding slots
// in the output weight tensor.
......@@ -111,6 +122,10 @@ class ComputeSession {
// Provides the ComputeSession with a batch of data to compute.
virtual void SetInputData(const std::vector<string> &data) = 0;
// Like SetInputData(), but accepts an InputBatchCache directly, potentially
// bypassing de-serialization.
virtual void SetInputBatchCache(std::unique_ptr<InputBatchCache> batch) = 0;
// Resets all components owned by this ComputeSession.
virtual void ResetSession() = 0;
......@@ -127,9 +142,14 @@ class ComputeSession {
// validate correct construction of translators in tests.
virtual const std::vector<const IndexTranslator *> Translators(
const string &component_name) const = 0;
// Get a given component. CHECK-fail if the component's IsReady method
// returns false.
virtual Component *GetReadiedComponent(
const string &component_name) const = 0;
};
} // namespace dragnn
} // namespace syntaxnet
#endif // NLP_SAFT_OPENSOURCE_DRAGNN_CORE_COMPUTE_SESSION_H_
#endif // DRAGNN_CORE_COMPUTE_SESSION_H_
......@@ -161,11 +161,11 @@ void ComputeSessionImpl::AdvanceFromOracle(const string &component_name) {
GetReadiedComponent(component_name)->AdvanceFromOracle();
}
void ComputeSessionImpl::AdvanceFromPrediction(const string &component_name,
const float score_matrix[],
int score_matrix_length) {
GetReadiedComponent(component_name)
->AdvanceFromPrediction(score_matrix, score_matrix_length);
bool ComputeSessionImpl::AdvanceFromPrediction(const string &component_name,
                                               const float *score_matrix,
                                               int num_items, int num_actions) {
  // Pure pass-through: look up the readied component and propagate its
  // success/failure result to the caller.
  auto *component = GetReadiedComponent(component_name);
  return component->AdvanceFromPrediction(score_matrix, num_items, num_actions);
}
int ComputeSessionImpl::GetInputFeatures(
......@@ -182,6 +182,16 @@ int ComputeSessionImpl::BulkGetInputFeatures(
return GetReadiedComponent(component_name)->BulkGetFixedFeatures(extractor);
}
void ComputeSessionImpl::BulkEmbedFixedFeatures(
    const string &component_name, int batch_size_padding, int num_steps_padding,
    int output_array_size, const vector<const float *> &per_channel_embeddings,
    float *embedding_output) {
  // Pure pass-through to the readied component; writes results into
  // |embedding_output|.
  auto *component = GetReadiedComponent(component_name);
  component->BulkEmbedFixedFeatures(batch_size_padding, num_steps_padding,
                                    output_array_size, per_channel_embeddings,
                                    embedding_output);
}
std::vector<LinkFeatures> ComputeSessionImpl::GetTranslatedLinkFeatures(
const string &component_name, int channel_id) {
auto *component = GetReadiedComponent(component_name);
......@@ -288,6 +298,11 @@ void ComputeSessionImpl::SetInputData(const std::vector<string> &data) {
input_data_.reset(new InputBatchCache(data));
}
// Takes ownership of |batch| as the session's input data, replacing any
// previously set input (and bypassing the de-serialization that
// SetInputData() would perform).
void ComputeSessionImpl::SetInputBatchCache(
    std::unique_ptr<InputBatchCache> batch) {
  input_data_ = std::move(batch);
}
void ComputeSessionImpl::ResetSession() {
// Reset all component states.
for (auto &component_pair : components_) {
......@@ -308,6 +323,7 @@ const std::vector<const IndexTranslator *> ComputeSessionImpl::Translators(
const string &component_name) const {
auto translators = GetTranslators(component_name);
std::vector<const IndexTranslator *> const_translators;
const_translators.reserve(translators.size());
for (const auto &translator : translators) {
const_translators.push_back(translator);
}
......
......@@ -13,8 +13,8 @@
// limitations under the License.
// =============================================================================
#ifndef NLP_SAFT_OPENSOURCE_DRAGNN_CORE_COMPUTE_SESSION_IMPL_H_
#define NLP_SAFT_OPENSOURCE_DRAGNN_CORE_COMPUTE_SESSION_IMPL_H_
#ifndef DRAGNN_CORE_COMPUTE_SESSION_IMPL_H_
#define DRAGNN_CORE_COMPUTE_SESSION_IMPL_H_
#include <memory>
......@@ -55,9 +55,9 @@ class ComputeSessionImpl : public ComputeSession {
void AdvanceFromOracle(const string &component_name) override;
void AdvanceFromPrediction(const string &component_name,
const float score_matrix[],
int score_matrix_length) override;
bool AdvanceFromPrediction(const string &component_name,
const float *score_matrix, int num_items,
int num_actions) override;
int GetInputFeatures(const string &component_name,
std::function<int32 *(int)> allocate_indices,
......@@ -68,6 +68,12 @@ class ComputeSessionImpl : public ComputeSession {
int BulkGetInputFeatures(const string &component_name,
const BulkFeatureExtractor &extractor) override;
void BulkEmbedFixedFeatures(
const string &component_name, int batch_size_padding,
int num_steps_padding, int output_array_size,
const vector<const float *> &per_channel_embeddings,
float *embedding_output) override;
std::vector<LinkFeatures> GetTranslatedLinkFeatures(
const string &component_name, int channel_id) override;
......@@ -84,6 +90,8 @@ class ComputeSessionImpl : public ComputeSession {
void SetInputData(const std::vector<string> &data) override;
void SetInputBatchCache(std::unique_ptr<InputBatchCache> batch) override;
void ResetSession() override;
void SetTracing(bool tracing_on) override;
......@@ -95,14 +103,14 @@ class ComputeSessionImpl : public ComputeSession {
const std::vector<const IndexTranslator *> Translators(
const string &component_name) const override;
// Get a given component. CHECK-fail if the component's IsReady method
// returns false.
Component *GetReadiedComponent(const string &component_name) const override;
private:
// Get a given component. Fails if the component is not found.
Component *GetComponent(const string &component_name) const;
// Get a given component. CHECK-fail if the component's IsReady method
// returns false.
Component *GetReadiedComponent(const string &component_name) const;
// Get the index translators for the given component.
const std::vector<IndexTranslator *> &GetTranslators(
const string &component_name) const;
......@@ -154,4 +162,4 @@ class ComputeSessionImpl : public ComputeSession {
} // namespace dragnn
} // namespace syntaxnet
#endif // NLP_SAFT_OPENSOURCE_DRAGNN_CORE_COMPUTE_SESSION_IMPL_H_
#endif // DRAGNN_CORE_COMPUTE_SESSION_IMPL_H_
......@@ -22,7 +22,9 @@
#include "dragnn/core/component_registry.h"
#include "dragnn/core/compute_session.h"
#include "dragnn/core/compute_session_pool.h"
#include "dragnn/core/input_batch_cache.h"
#include "dragnn/core/interfaces/component.h"
#include "dragnn/core/interfaces/input_batch.h"
#include "dragnn/core/test/generic.h"
#include "dragnn/core/test/mock_component.h"
#include "dragnn/core/test/mock_transition_state.h"
......@@ -65,8 +67,10 @@ class TestComponentType1 : public Component {
int GetSourceBeamIndex(int current_index, int batch) const override {
return 0;
}
void AdvanceFromPrediction(const float transition_matrix[],
int matrix_length) override {}
bool AdvanceFromPrediction(const float *score_matrix, int num_items,
int num_actions) override {
return true;
}
void AdvanceFromOracle() override {}
bool IsTerminal() const override { return true; }
std::function<int(int, int, int)> GetStepLookupFunction(
......@@ -83,6 +87,10 @@ class TestComponentType1 : public Component {
int channel_id) const override {
return 0;
}
void BulkEmbedFixedFeatures(
int batch_size_padding, int num_steps_padding, int embedding_size,
const vector<const float *> &per_channel_embeddings,
float *embedding_output) override {}
int BulkGetFixedFeatures(const BulkFeatureExtractor &extractor) override {
return 0;
}
......@@ -133,8 +141,10 @@ class TestComponentType2 : public Component {
int GetSourceBeamIndex(int current_index, int batch) const override {
return 0;
}
void AdvanceFromPrediction(const float transition_matrix[],
int matrix_length) override {}
bool AdvanceFromPrediction(const float *score_matrix, int num_items,
int num_actions) override {
return true;
}
void AdvanceFromOracle() override {}
bool IsTerminal() const override { return true; }
std::function<int(int, int, int)> GetStepLookupFunction(
......@@ -151,6 +161,10 @@ class TestComponentType2 : public Component {
int channel_id) const override {
return 0;
}
void BulkEmbedFixedFeatures(
int batch_size_padding, int num_steps_padding, int embedding_size,
const vector<const float *> &per_channel_embeddings,
float *embedding_output) override {}
int BulkGetFixedFeatures(const BulkFeatureExtractor &extractor) override {
return 0;
}
......@@ -201,8 +215,14 @@ class UnreadyComponent : public Component {
int GetSourceBeamIndex(int current_index, int batch) const override {
return 0;
}
void AdvanceFromPrediction(const float transition_matrix[],
int matrix_length) override {}
bool AdvanceFromPrediction(const float *score_matrix, int num_items,
int num_actions) override {
return true;
}
void BulkEmbedFixedFeatures(
int batch_size_padding, int num_steps_padding, int embedding_size,
const vector<const float *> &per_channel_embeddings,
float *embedding_output) override {}
void AdvanceFromOracle() override {}
bool IsTerminal() const override { return false; }
std::function<int(int, int, int)> GetStepLookupFunction(
......@@ -254,6 +274,18 @@ class ComputeSessionImplTestPoolAccessor {
}
};
// An InputBatch that uses the serialized data directly: the stored strings
// are both the input and the serialized output. Test-only helper.
class IdentityBatch : public InputBatch {
 public:
  // Implements InputBatch.
  // Stores a copy of |data| as the batch contents.
  void SetData(const std::vector<string> &data) override { data_ = data; }
  // Returns the number of elements in the batch.
  int GetSize() const override { return data_.size(); }
  // Returns the stored strings unchanged (identity serialization).
  const std::vector<string> GetSerializedData() const override { return data_; }

 private:
  std::vector<string> data_;  // the batch data
};
// *****************************************************************************
// Tests begin here.
// *****************************************************************************
......@@ -739,7 +771,7 @@ TEST(ComputeSessionImplTest, InitializesComponentWithSource) {
EXPECT_CALL(*mock_components["component_one"], GetBeam())
.WillOnce(Return(beam));
// Expect that the second component will recieve that beam.
// Expect that the second component will receive that beam.
EXPECT_CALL(*mock_components["component_two"],
InitializeData(beam, kMaxBeamSize, NotNull()));
......@@ -899,7 +931,7 @@ TEST(ComputeSessionImplTest, SetTracingPropagatesToAllComponents) {
EXPECT_CALL(*mock_components["component_one"], GetBeam())
.WillOnce(Return(beam));
// Expect that the second component will recieve that beam, and then its
// Expect that the second component will receive that beam, and then its
// tracing will be initialized.
EXPECT_CALL(*mock_components["component_two"],
InitializeData(beam, kMaxBeamSize, NotNull()));
......@@ -1084,12 +1116,12 @@ TEST(ComputeSessionImplTest, InterfacePassesThrough) {
session->AdvanceFromOracle("component_one");
// AdvanceFromPrediction()
constexpr int kScoreMatrixLength = 3;
const float score_matrix[kScoreMatrixLength] = {1.0, 2.3, 4.5};
const int kNumActions = 1;
const float score_matrix[] = {1.0, 2.3, 4.5};
EXPECT_CALL(*mock_components["component_one"],
AdvanceFromPrediction(score_matrix, kScoreMatrixLength));
session->AdvanceFromPrediction("component_one", score_matrix,
kScoreMatrixLength);
AdvanceFromPrediction(score_matrix, batch_size, kNumActions));
session->AdvanceFromPrediction("component_one", score_matrix, batch_size,
kNumActions);
// GetFixedFeatures
auto allocate_indices = [](int size) -> int32 * { return nullptr; };
......@@ -1109,6 +1141,11 @@ TEST(ComputeSessionImplTest, InterfacePassesThrough) {
.WillOnce(Return(0));
EXPECT_EQ(0, session->BulkGetInputFeatures("component_one", extractor));
// BulkEmbedFixedFeatures
EXPECT_CALL(*mock_components["component_one"],
BulkEmbedFixedFeatures(1, 2, 3, _, _));
session->BulkEmbedFixedFeatures("component_one", 1, 2, 3, {nullptr}, nullptr);
// EmitOracleLabels()
std::vector<std::vector<int>> oracle_labels = {{0, 1}, {2, 3}};
EXPECT_CALL(*mock_components["component_one"], GetOracleLabels())
......@@ -1154,7 +1191,7 @@ TEST(ComputeSessionImplTest, InterfaceRequiresReady) {
constexpr int kScoreMatrixLength = 3;
const float score_matrix[kScoreMatrixLength] = {1.0, 2.3, 4.5};
EXPECT_DEATH(session->AdvanceFromPrediction("component_one", score_matrix,
kScoreMatrixLength),
kScoreMatrixLength, 1),
"without first initializing it");
constexpr int kArbitraryChannelId = 3;
EXPECT_DEATH(session->GetInputFeatures("component_one", nullptr, nullptr,
......@@ -1163,10 +1200,32 @@ TEST(ComputeSessionImplTest, InterfaceRequiresReady) {
BulkFeatureExtractor extractor(nullptr, nullptr, nullptr, false, 0, 0);
EXPECT_DEATH(session->BulkGetInputFeatures("component_one", extractor),
"without first initializing it");
EXPECT_DEATH(session->BulkEmbedFixedFeatures("component_one", 0, 0, 0,
{nullptr}, nullptr),
"without first initializing it");
EXPECT_DEATH(
session->GetTranslatedLinkFeatures("component_one", kArbitraryChannelId),
"without first initializing it");
}
TEST(ComputeSessionImplTest, SetInputBatchCache) {
  // Components are never exercised here, so empty protos suffice.
  MasterSpec spec;
  GridPoint hyperparams;
  ComputeSessionPool pool(spec, hyperparams);
  auto session = pool.GetSession();

  // Build an InputBatchCache whose conversion to an IdentityBatch has
  // already been performed.
  const std::vector<string> expected = {"foo", "bar", "baz"};
  std::unique_ptr<InputBatchCache> cache(new InputBatchCache(expected));
  cache->GetAs<IdentityBatch>();

  // Hand ownership of the cache to the session.
  session->SetInputBatchCache(std::move(cache));

  // The injected batch should round-trip back out unchanged.
  EXPECT_EQ(session->GetSerializedPredictions(), expected);
}
} // namespace dragnn
} // namespace syntaxnet
......@@ -13,8 +13,8 @@
// limitations under the License.
// =============================================================================
#ifndef NLP_SAFT_OPENSOURCE_DRAGNN_CORE_COMPUTE_SESSION_POOL_H_
#define NLP_SAFT_OPENSOURCE_DRAGNN_CORE_COMPUTE_SESSION_POOL_H_
#ifndef DRAGNN_CORE_COMPUTE_SESSION_POOL_H_
#define DRAGNN_CORE_COMPUTE_SESSION_POOL_H_
#include <memory>
......@@ -29,14 +29,14 @@ namespace dragnn {
class ComputeSessionPool {
public:
// Create a ComputeSessionPool that creates ComputeSessions for the given
// Creates a ComputeSessionPool that creates ComputeSessions for the given
// MasterSpec and hyperparameters.
ComputeSessionPool(const MasterSpec &master_spec,
const GridPoint &hyperparams);
virtual ~ComputeSessionPool();
// Get a ComputeSession. This function will attempt to use an already-created
// Gets a ComputeSession. This function will attempt to use an already-created
// ComputeSession, but if none are available a new one will be created.
std::unique_ptr<ComputeSession> GetSession();
......@@ -49,6 +49,12 @@ class ComputeSessionPool {
return num_unique_sessions_ - sessions_.size();
}
// Returns the number of unique sessions that have been created.
int num_unique_sessions() { return num_unique_sessions_; }
// Returns a reference to the underlying spec for this pool.
const MasterSpec &GetSpec() const { return master_spec_; }
private:
friend class ComputeSessionImplTestPoolAccessor;
friend class ComputeSessionPoolTestPoolAccessor;
......@@ -99,4 +105,4 @@ class ComputeSessionPool {
} // namespace dragnn
} // namespace syntaxnet
#endif // NLP_SAFT_OPENSOURCE_DRAGNN_CORE_COMPUTE_SESSION_POOL_H_
#endif // DRAGNN_CORE_COMPUTE_SESSION_POOL_H_
......@@ -207,6 +207,7 @@ TEST(ComputeSessionPoolTest, SupportsMultithreadedAccess) {
std::vector<std::unique_ptr<tensorflow::Thread>> request_threads;
constexpr int kNumThreadsToTest = 100;
request_threads.reserve(kNumThreadsToTest);
for (int i = 0; i < kNumThreadsToTest; ++i) {
request_threads.push_back(std::unique_ptr<tensorflow::Thread>(
tensorflow::Env::Default()->StartThread(
......
......@@ -13,8 +13,8 @@
// limitations under the License.
// =============================================================================
#ifndef NLP_SAFT_OPENSOURCE_DRAGNN_CORE_INDEX_TRANSLATOR_H_
#define NLP_SAFT_OPENSOURCE_DRAGNN_CORE_INDEX_TRANSLATOR_H_
#ifndef DRAGNN_CORE_INDEX_TRANSLATOR_H_
#define DRAGNN_CORE_INDEX_TRANSLATOR_H_
#include <memory>
#include <vector>
......@@ -80,4 +80,4 @@ class IndexTranslator {
} // namespace dragnn
} // namespace syntaxnet
#endif // NLP_SAFT_OPENSOURCE_DRAGNN_CORE_INDEX_TRANSLATOR_H_
#endif // DRAGNN_CORE_INDEX_TRANSLATOR_H_
......@@ -13,12 +13,15 @@
// limitations under the License.
// =============================================================================
#ifndef NLP_SAFT_OPENSOURCE_DRAGNN_CORE_INPUT_BATCH_CACHE_H_
#define NLP_SAFT_OPENSOURCE_DRAGNN_CORE_INPUT_BATCH_CACHE_H_
#ifndef DRAGNN_CORE_INPUT_BATCH_CACHE_H_
#define DRAGNN_CORE_INPUT_BATCH_CACHE_H_
#include <memory>
#include <string>
#include <type_traits>
#include <typeindex>
#include <typeinfo>
#include <utility>
#include "dragnn/core/interfaces/input_batch.h"
#include "tensorflow/core/platform/logging.h"
......@@ -42,6 +45,18 @@ class InputBatchCache {
// Creates an InputBatchCache holding the serialized strings in |data|. The
// strings are converted to a typed InputBatch lazily, on the first call to
// GetAs().
explicit InputBatchCache(const std::vector<string> &data)
    : stored_type_(std::type_index(typeid(void))), source_data_(data) {}
// Creates an InputBatchCache that takes ownership of the already-converted
// |batch|. InputBatchSubclass must be a strict subclass of InputBatch, and
// |batch| must be non-null (CHECK-failure otherwise). All subsequent calls to
// GetAs() must use InputBatchSubclass, and AddData() will CHECK-fail since
// the cache is already populated.
template <class InputBatchSubclass>
explicit InputBatchCache(std::unique_ptr<InputBatchSubclass> batch)
    : stored_type_(std::type_index(typeid(InputBatchSubclass))),
      converted_data_(std::move(batch)) {
  static_assert(IsStrictInputBatchSubclass<InputBatchSubclass>(),
                "InputBatchCache requires a strict subclass of InputBatch");
  CHECK(converted_data_) << "Cannot initialize from a null InputBatch";
}
// Adds a single string to the cache. Only usable before GetAs() has been
// called.
void AddData(const string &data) {
......@@ -52,10 +67,14 @@ class InputBatchCache {
}
// Converts the stored strings into protos and return them in a specific
// InputBatch subclass. T should always be of type InputBatch. After this
// method is called once, all further calls must be of the same data type.
// InputBatch subclass. T should always be a strict subclass of InputBatch.
// After this method is called once, all further calls must be of the same
// data type.
template <class T>
T *GetAs() {
static_assert(
IsStrictInputBatchSubclass<T>(),
"GetAs<T>() requires that T is a strict subclass of InputBatch");
if (!converted_data_) {
stored_type_ = std::type_index(typeid(T));
converted_data_.reset(new T());
......@@ -69,14 +88,27 @@ class InputBatchCache {
return dynamic_cast<T *>(converted_data_.get());
}
// Returns the number of elements in the batch. Requires that the cache has
// been populated, either via GetAs() or the InputBatch-taking constructor;
// CHECK-fails otherwise.
int Size() const {
  CHECK(converted_data_) << "Cannot return batch size without data.";
  return converted_data_->GetSize();
}
// Returns a copy of the serialized representation of the data held in the
// input batch object within this cache. Requires that the cache has been
// populated, either via GetAs() or the InputBatch-taking constructor;
// CHECK-fails otherwise.
const std::vector<string> SerializedData() const {
  CHECK(converted_data_) << "Cannot return batch without data.";
  return converted_data_->GetSerializedData();
}
private:
// True iff InputBatchSubclass derives from InputBatch and is not InputBatch
// itself.
template <class InputBatchSubclass>
static constexpr bool IsStrictInputBatchSubclass() {
  return !std::is_same<InputBatch, InputBatchSubclass>::value &&
         std::is_base_of<InputBatch, InputBatchSubclass>::value;
}
// The typeid of the stored data.
std::type_index stored_type_;
......@@ -90,4 +122,4 @@ class InputBatchCache {
} // namespace dragnn
} // namespace syntaxnet
#endif // NLP_SAFT_OPENSOURCE_DRAGNN_CORE_INPUT_BATCH_CACHE_H_
#endif // DRAGNN_CORE_INPUT_BATCH_CACHE_H_
......@@ -32,6 +32,8 @@ class StringData : public InputBatch {
}
}
int GetSize() const override { return data_.size(); }
const std::vector<string> GetSerializedData() const override { return data_; }
std::vector<string> *data() { return &data_; }
......@@ -50,6 +52,8 @@ class DifferentStringData : public InputBatch {
}
}
int GetSize() const override { return data_.size(); }
const std::vector<string> GetSerializedData() const override { return data_; }
std::vector<string> *data() { return &data_; }
......@@ -58,6 +62,11 @@ class DifferentStringData : public InputBatch {
std::vector<string> data_;
};
// Non-fatally asserts that |pointer1| and |pointer2| hold the same address.
// Taking const void* lets callers compare pointers to unrelated types.
void ExpectSameAddress(const void *pointer1, const void *pointer2) {
  EXPECT_EQ(pointer1, pointer2);
}
TEST(InputBatchCacheTest, ConvertsSingleInput) {
string test_string = "Foo";
InputBatchCache generic_set(test_string);
......@@ -118,5 +127,48 @@ TEST(InputBatchCacheTest, ConvertsAddedInputDiesAfterGetAs) {
"after the cache has been converted");
}
// Validates that Size() and SerializedData() report the converted batch after
// GetAs() has populated the cache.
TEST(InputBatchCacheTest, SerializedDataAndSize) {
  InputBatchCache cache;
  for (const char *item : {"Foo", "Bar"}) {
    cache.AddData(item);
  }
  cache.GetAs<StringData>();
  EXPECT_EQ(2, cache.Size());
  const std::vector<string> kExpected = {"Foo_converted", "Bar_converted"};
  EXPECT_EQ(kExpected, cache.SerializedData());
}
// Checks that a cache constructed from a preconverted InputBatch exposes that
// exact batch object, and rejects both further mutation and type changes.
TEST(InputBatchCacheTest, InitializeFromInputBatch) {
  std::unique_ptr<StringData> batch(new StringData());
  batch->SetData({"foo", "bar", "baz"});
  const StringData *raw_batch = batch.get();

  InputBatchCache cache(std::move(batch));
  StringData *retrieved = cache.GetAs<StringData>();
  ExpectSameAddress(raw_batch, retrieved);

  const std::vector<string> kExpected = {"foo_converted", "bar_converted",
                                         "baz_converted"};
  EXPECT_EQ(retrieved->GetSize(), 3);
  EXPECT_EQ(retrieved->GetSerializedData(), kExpected);
  EXPECT_EQ(*retrieved->data(), kExpected);

  // The cache is already populated, so AddData() must die...
  EXPECT_DEATH(cache.AddData("YOU MAY NOT DO THIS AND IT WILL DIE."),
               "after the cache has been converted");

  // ...and GetAs() with a mismatched type must also die.
  EXPECT_DEATH(cache.GetAs<DifferentStringData>(),
               "Attempted to convert to two object types!");
}
// Passing a null batch pointer to the InputBatch constructor is a fatal error.
TEST(InputBatchCacheTest, CannotInitializeFromNullInputBatch) {
  std::unique_ptr<StringData> null_batch;
  EXPECT_DEATH(InputBatchCache(std::move(null_batch)),
               "Cannot initialize from a null InputBatch");
}
} // namespace dragnn
} // namespace syntaxnet
......@@ -13,8 +13,8 @@
// limitations under the License.
// =============================================================================
#ifndef NLP_SAFT_OPENSOURCE_DRAGNN_CORE_INTERFACES_CLONEABLE_TRANSITION_STATE_H_
#define NLP_SAFT_OPENSOURCE_DRAGNN_CORE_INTERFACES_CLONEABLE_TRANSITION_STATE_H_
#ifndef DRAGNN_CORE_INTERFACES_CLONEABLE_TRANSITION_STATE_H_
#define DRAGNN_CORE_INTERFACES_CLONEABLE_TRANSITION_STATE_H_
#include <memory>
#include <vector>
......@@ -33,26 +33,32 @@ class CloneableTransitionState : public TransitionState {
public:
~CloneableTransitionState<T>() override {}
// Initialize this TransitionState from a previous TransitionState. The
// Initializes this TransitionState from a previous TransitionState. The
// ParentBeamIndex is the location of that previous TransitionState in the
// provided beam.
void Init(const TransitionState &parent) override = 0;
// Return the beam index of the state passed into the initializer of this
// Returns the beam index of the state passed into the initializer of this
// TransitionState.
const int ParentBeamIndex() const override = 0;
int ParentBeamIndex() const override = 0;
// Get the current beam index for this state.
const int GetBeamIndex() const override = 0;
// Gets the current beam index for this state.
int GetBeamIndex() const override = 0;
// Set the current beam index for this state.
void SetBeamIndex(const int index) override = 0;
// Sets the current beam index for this state.
void SetBeamIndex(int index) override = 0;
// Get the score associated with this transition state.
const float GetScore() const override = 0;
// Gets the score associated with this transition state.
float GetScore() const override = 0;
// Set the score associated with this transition state.
void SetScore(const float score) override = 0;
// Sets the score associated with this transition state.
void SetScore(float score) override = 0;
// Gets the gold-ness of this state (whether it is on the oracle path)
bool IsGold() const override = 0;
// Sets the gold-ness of this state.
void SetGold(bool is_gold) override = 0;
// Depicts this state as an HTML-language string.
string HTMLRepresentation() const override = 0;
......@@ -64,4 +70,4 @@ class CloneableTransitionState : public TransitionState {
} // namespace dragnn
} // namespace syntaxnet
#endif // NLP_SAFT_OPENSOURCE_DRAGNN_CORE_INTERFACES_CLONEABLE_TRANSITION_STATE_H_
#endif // DRAGNN_CORE_INTERFACES_CLONEABLE_TRANSITION_STATE_H_
......@@ -13,8 +13,8 @@
// limitations under the License.
// =============================================================================
#ifndef NLP_SAFT_OPENSOURCE_DRAGNN_CORE_INTERFACES_COMPONENT_H_
#define NLP_SAFT_OPENSOURCE_DRAGNN_CORE_INTERFACES_COMPONENT_H_
#ifndef DRAGNN_CORE_INTERFACES_COMPONENT_H_
#define DRAGNN_CORE_INTERFACES_COMPONENT_H_
#include <vector>
......@@ -83,11 +83,13 @@ class Component : public RegisterableClass<Component> {
virtual std::function<int(int, int, int)> GetStepLookupFunction(
const string &method) = 0;
// Advances this component from the given transition matrix.
virtual void AdvanceFromPrediction(const float transition_matrix[],
int transition_matrix_length) = 0;
// Advances this component from the given transition matrix, which is
// |num_items| x |num_actions|.
virtual bool AdvanceFromPrediction(const float *score_matrix, int num_items,
int num_actions) = 0;
// Advances this component from the state oracles.
// Advances this component from the state oracles. There is no return from
// this, since it should always succeed.
virtual void AdvanceFromOracle() = 0;
// Returns true if all states within this component are terminal.
......@@ -110,6 +112,14 @@ class Component : public RegisterableClass<Component> {
// BulkFeatureExtractor object to contain the functors and other information.
virtual int BulkGetFixedFeatures(const BulkFeatureExtractor &extractor) = 0;
// Directly computes the embedding matrix for all channels, advancing the
// component via the oracle until it is terminal. This call takes a vector
// of EmbeddingMatrix structs, one per channel, in channel order.
virtual void BulkEmbedFixedFeatures(
int batch_size_padding, int num_steps_padding, int output_array_size,
const vector<const float *> &per_channel_embeddings,
float *embedding_output) = 0;
// Extracts and returns the vector of LinkFeatures for the specified
// channel. Note: these are NOT translated.
virtual std::vector<LinkFeatures> GetRawLinkFeatures(
......@@ -138,4 +148,4 @@ class Component : public RegisterableClass<Component> {
} // namespace dragnn
} // namespace syntaxnet
#endif // NLP_SAFT_OPENSOURCE_DRAGNN_CORE_INTERFACES_COMPONENT_H_
#endif // DRAGNN_CORE_INTERFACES_COMPONENT_H_
......@@ -13,8 +13,8 @@
// limitations under the License.
// =============================================================================
#ifndef NLP_SAFT_OPENSOURCE_DRAGNN_CORE_INTERFACES_INPUT_BATCH_H_
#define NLP_SAFT_OPENSOURCE_DRAGNN_CORE_INTERFACES_INPUT_BATCH_H_
#ifndef DRAGNN_CORE_INTERFACES_INPUT_BATCH_H_
#define DRAGNN_CORE_INTERFACES_INPUT_BATCH_H_
#include <string>
#include <vector>
......@@ -32,14 +32,17 @@ class InputBatch {
public:
virtual ~InputBatch() {}
// Set the data to translate to the subclass' data type.
// Sets the data to translate to the subclass' data type. Call at most once.
virtual void SetData(const std::vector<string> &data) = 0;
// Translate the underlying data back to a vector of strings, as appropriate.
// Returns the size of the batch.
virtual int GetSize() const = 0;
// Translates the underlying data back to a vector of strings, as appropriate.
virtual const std::vector<string> GetSerializedData() const = 0;
};
} // namespace dragnn
} // namespace syntaxnet
#endif // NLP_SAFT_OPENSOURCE_DRAGNN_CORE_INTERFACES_INPUT_BATCH_H_
#endif // DRAGNN_CORE_INTERFACES_INPUT_BATCH_H_
......@@ -13,8 +13,8 @@
// limitations under the License.
// =============================================================================
#ifndef NLP_SAFT_OPENSOURCE_DRAGNN_CORE_INTERFACES_TRANSITION_STATE_H_
#define NLP_SAFT_OPENSOURCE_DRAGNN_CORE_INTERFACES_TRANSITION_STATE_H_
#ifndef DRAGNN_CORE_INTERFACES_TRANSITION_STATE_H_
#define DRAGNN_CORE_INTERFACES_TRANSITION_STATE_H_
#include <memory>
#include <vector>
......@@ -44,19 +44,25 @@ class TransitionState {
// Return the beam index of the state passed into the initializer of this
// TransitionState.
virtual const int ParentBeamIndex() const = 0;
virtual int ParentBeamIndex() const = 0;
// Get the current beam index for this state.
virtual const int GetBeamIndex() const = 0;
// Gets the current beam index for this state.
virtual int GetBeamIndex() const = 0;
// Set the current beam index for this state.
virtual void SetBeamIndex(const int index) = 0;
// Sets the current beam index for this state.
virtual void SetBeamIndex(int index) = 0;
// Get the score associated with this transition state.
virtual const float GetScore() const = 0;
// Gets the score associated with this transition state.
virtual float GetScore() const = 0;
// Set the score associated with this transition state.
virtual void SetScore(const float score) = 0;
// Sets the score associated with this transition state.
virtual void SetScore(float score) = 0;
// Gets the gold-ness of this state (whether it is on the oracle path)
virtual bool IsGold() const = 0;
// Sets the gold-ness of this state.
virtual void SetGold(bool is_gold) = 0;
// Depicts this state as an HTML-language string.
virtual string HTMLRepresentation() const = 0;
......@@ -65,4 +71,4 @@ class TransitionState {
} // namespace dragnn
} // namespace syntaxnet
#endif // NLP_SAFT_OPENSOURCE_DRAGNN_CORE_INTERFACES_TRANSITION_STATE_H_
#endif // DRAGNN_CORE_INTERFACES_TRANSITION_STATE_H_
......@@ -13,8 +13,8 @@
// limitations under the License.
// =============================================================================
#ifndef NLP_SAFT_OPENSOURCE_DRAGNN_CORE_OPS_COMPUTE_SESSION_OP_H_
#define NLP_SAFT_OPENSOURCE_DRAGNN_CORE_OPS_COMPUTE_SESSION_OP_H_
#ifndef DRAGNN_CORE_OPS_COMPUTE_SESSION_OP_H_
#define DRAGNN_CORE_OPS_COMPUTE_SESSION_OP_H_
#include <string>
......@@ -66,4 +66,4 @@ class ComputeSessionOp : public tensorflow::OpKernel {
} // namespace dragnn
} // namespace syntaxnet
#endif // NLP_SAFT_OPENSOURCE_DRAGNN_CORE_OPS_COMPUTE_SESSION_OP_H_
#endif // DRAGNN_CORE_OPS_COMPUTE_SESSION_OP_H_
......@@ -303,6 +303,73 @@ class BulkFixedEmbeddings : public ComputeSessionOp {
REGISTER_KERNEL_BUILDER(Name("BulkFixedEmbeddings").Device(DEVICE_CPU),
BulkFixedEmbeddings);
// See docstring in dragnn_bulk_ops.cc.
//
// Bulk-embeds fixed features for one component: hands the per-channel
// embedding matrices to the ComputeSession, which advances the component via
// the oracle until terminal while writing concatenated feature embeddings
// into a single padded output tensor.
class BulkEmbedFixedFeatures : public ComputeSessionOp {
 public:
  explicit BulkEmbedFixedFeatures(OpKernelConstruction *context)
      : ComputeSessionOp(context) {
    OP_REQUIRES_OK(context, context->GetAttr("num_channels", &num_channels_));

    // The input vector's zeroth element is the state handle, and the remaining
    // num_channels_ elements are tensors of float embeddings, one per channel.
    vector<DataType> input_types(num_channels_ + 1, DT_FLOAT);
    input_types[0] = DT_STRING;
    const vector<DataType> output_types = {DT_STRING, DT_FLOAT, DT_INT32};
    OP_REQUIRES_OK(context, context->MatchSignature(input_types, output_types));
    OP_REQUIRES_OK(context, context->GetAttr("pad_to_batch", &pad_to_batch_));
    OP_REQUIRES_OK(context, context->GetAttr("pad_to_steps", &pad_to_steps_));
  }

  // This op forwards the ComputeSession handle as output 0.
  bool OutputsHandle() const override { return true; }

  // The op must be bound to a specific component via the "component" attr.
  bool RequiresComponentName() const override { return true; }

  void ComputeWithState(OpKernelContext *context,
                        ComputeSession *session) override {
    const auto &spec = session->Spec(component_name());

    // Each channel contributes (embedding dim) * (feature count) floats to
    // the concatenated per-step embedding vector; collect the raw matrix
    // pointers in channel order. Input 0 is the handle, so channel i's matrix
    // is input i + 1.
    int embedding_size = 0;
    std::vector<const float *> embeddings(num_channels_);
    for (int channel = 0; channel < num_channels_; ++channel) {
      const int embeddings_index = channel + 1;
      embedding_size += context->input(embeddings_index).shape().dim_size(1) *
                        spec.fixed_feature(channel).size();
      embeddings[channel] =
          context->input(embeddings_index).flat<float>().data();
    }

    // Output 1: [pad_to_steps * pad_to_batch * beam_size, embedding_size].
    Tensor *embedding_vectors;
    OP_REQUIRES_OK(context,
                   context->allocate_output(
                       1,
                       TensorShape({pad_to_steps_ * pad_to_batch_ *
                                        session->BeamSize(component_name()),
                                    embedding_size}),
                       &embedding_vectors));

    // Output 2: scalar number of (padded) steps.
    Tensor *num_steps_tensor;
    OP_REQUIRES_OK(context, context->allocate_output(2, TensorShape({}),
                                                     &num_steps_tensor));

    // Zero the whole output first so padding regions are well defined.
    embedding_vectors->flat<float>().setZero();
    // NOTE(review): NumElements() returns int64; narrowing to int assumes the
    // padded output has fewer than 2^31 elements -- confirm for large batches.
    int output_size = embedding_vectors->NumElements();
    session->BulkEmbedFixedFeatures(component_name(), pad_to_batch_,
                                    pad_to_steps_, output_size, embeddings,
                                    embedding_vectors->flat<float>().data());
    num_steps_tensor->scalar<int32>()() = pad_to_steps_;
  }

 private:
  // Number of fixed feature channels.
  int num_channels_;

  // Will pad output to this many batch elements.
  int pad_to_batch_;

  // Will pad output to this many steps.
  int pad_to_steps_;

  TF_DISALLOW_COPY_AND_ASSIGN(BulkEmbedFixedFeatures);
};
REGISTER_KERNEL_BUILDER(Name("BulkEmbedFixedFeatures").Device(DEVICE_CPU),
BulkEmbedFixedFeatures);
// See docstring in dragnn_bulk_ops.cc.
class BulkAdvanceFromOracle : public ComputeSessionOp {
public:
......@@ -387,8 +454,11 @@ class BulkAdvanceFromPrediction : public ComputeSessionOp {
}
}
if (!session->IsTerminal(component_name())) {
session->AdvanceFromPrediction(component_name(), scores_per_step.data(),
scores_per_step.size());
bool success = session->AdvanceFromPrediction(
component_name(), scores_per_step.data(), num_items, num_actions);
OP_REQUIRES(
context, success,
tensorflow::errors::Internal("Unable to advance from prediction."));
}
}
}
......
......@@ -375,6 +375,114 @@ TEST_F(DragnnBulkOpKernelsTest, BulkFixedEmbeddings) {
EXPECT_EQ(kNumSteps, GetOutput(2)->scalar<int32>()());
}
// Checks that the BulkEmbedFixedFeatures kernel forwards the padding attrs
// and per-channel embedding matrices to the ComputeSession, and surfaces the
// session-filled embedding output plus the padded step count.
//
// Fix: the channel-1 matrix-fill loop previously logged
// embedding_matrix_0.back() (copy-paste from the channel-0 loop); it now logs
// the matrix it actually fills.
TEST_F(DragnnBulkOpKernelsTest, BulkEmbedFixedFeatures) {
  // Create and initialize the kernel under test.
  constexpr int kBatchPad = 7;
  constexpr int kStepPad = 5;
  constexpr int kMaxBeamSize = 3;
  TF_ASSERT_OK(
      NodeDefBuilder("BulkEmbedFixedFeatures", "BulkEmbedFixedFeatures")
          .Attr("component", kComponentName)
          .Attr("num_channels", kNumChannels)
          .Attr("pad_to_batch", kBatchPad)
          .Attr("pad_to_steps", kStepPad)
          .Input(FakeInput(DT_STRING))  // The handle for the ComputeSession.
          .Input(FakeInput(DT_FLOAT))   // Embedding matrices.
          .Finalize(node_def()));
  MockComputeSession *mock_session = GetMockSession();

  // Component spec with two fixed-feature channels of different sizes.
  ComponentSpec spec;
  spec.set_name(kComponentName);
  auto chan0_spec = spec.add_fixed_feature();
  constexpr int kChan0FeatureCount = 2;
  chan0_spec->set_size(kChan0FeatureCount);
  auto chan1_spec = spec.add_fixed_feature();
  constexpr int kChan1FeatureCount = 1;
  chan1_spec->set_size(kChan1FeatureCount);
  EXPECT_CALL(*mock_session, Spec(kComponentName))
      .WillOnce(testing::ReturnRef(spec));
  EXPECT_CALL(*mock_session, BeamSize(kComponentName))
      .WillOnce(testing::Return(kMaxBeamSize));

  // Embedding matrices as additional inputs.
  // For channel 0, the embeddings are [id, id, id].
  // For channel 1, the embeddings are [id^2, id^2, id^2, ... ,id^2].
  vector<float> embedding_matrix_0;
  constexpr int kEmbedding0Size = 3;
  vector<float> embedding_matrix_1;
  constexpr int kEmbedding1Size = 9;
  for (int id = 0; id < kNumIds; ++id) {
    for (int i = 0; i < kEmbedding0Size; ++i) {
      embedding_matrix_0.push_back(id);
      LOG(INFO) << embedding_matrix_0.back();
    }
    for (int i = 0; i < kEmbedding1Size; ++i) {
      embedding_matrix_1.push_back(id * id);
      // Log the matrix this loop fills (was embedding_matrix_0.back()).
      LOG(INFO) << embedding_matrix_1.back();
    }
  }
  AddInputFromArray<float>(TensorShape({kNumIds, kEmbedding0Size}),
                           embedding_matrix_0);
  AddInputFromArray<float>(TensorShape({kNumIds, kEmbedding1Size}),
                           embedding_matrix_1);

  // The concatenated per-step embedding spans both channels; the output is
  // padded along batch, steps, and beam.
  constexpr int kExpectedEmbeddingSize = kChan0FeatureCount * kEmbedding0Size +
                                         kChan1FeatureCount * kEmbedding1Size;
  constexpr int kExpectedOutputSize =
      kExpectedEmbeddingSize * kBatchPad * kStepPad * kMaxBeamSize;

  // This function takes the allocator functions passed into GetBulkFF, uses
  // them to allocate a tensor, then fills that tensor based on channel.
  auto eval_function = [=](const string &component_name, int batch_size_padding,
                           int num_steps_padding, int output_array_size,
                           const vector<const float *> &per_channel_embeddings,
                           float *embedding_output) {
    // Validate the control variables.
    EXPECT_EQ(batch_size_padding, kBatchPad);
    EXPECT_EQ(num_steps_padding, kStepPad);
    EXPECT_EQ(output_array_size, kExpectedOutputSize);

    // Validate the passed embeddings.
    for (int i = 0; i < kNumIds; ++i) {
      for (int j = 0; j < kEmbedding0Size; ++j) {
        float ch0_embedding =
            per_channel_embeddings.at(0)[i * kEmbedding0Size + j];
        EXPECT_FLOAT_EQ(ch0_embedding, i)
            << "Failed match at " << i << "," << j;
      }
      for (int j = 0; j < kEmbedding1Size; ++j) {
        float ch1_embedding =
            per_channel_embeddings.at(1)[i * kEmbedding1Size + j];
        EXPECT_FLOAT_EQ(ch1_embedding, i * i)
            << "Failed match at " << i << "," << j;
      }
    }

    // Fill the output matrix to the expected size. This will trigger msan
    // if the allocation wasn't big enough.
    for (int i = 0; i < kExpectedOutputSize; ++i) {
      embedding_output[i] = i;
    }
  };
  EXPECT_CALL(*mock_session,
              BulkEmbedFixedFeatures(kComponentName, _, _, _, _, _))
      .WillOnce(testing::Invoke(eval_function));

  // Run the kernel.
  TF_EXPECT_OK(RunOpKernelWithContext());

  // Validate outputs: shape, the mock-written contents, and the step count.
  EXPECT_EQ(kBatchPad * kStepPad * kMaxBeamSize,
            GetOutput(1)->shape().dim_size(0));
  EXPECT_EQ(kExpectedEmbeddingSize, GetOutput(1)->shape().dim_size(1));
  auto output_data = GetOutput(1)->flat<float>();
  for (int i = 0; i < kExpectedOutputSize; ++i) {
    EXPECT_FLOAT_EQ(i, output_data(i));
  }
  EXPECT_EQ(kStepPad, GetOutput(2)->scalar<int32>()());
}
TEST_F(DragnnBulkOpKernelsTest, BulkFixedEmbeddingsWithPadding) {
// Create and initialize the kernel under test.
constexpr int kPaddedNumSteps = 5;
......@@ -592,12 +700,54 @@ TEST_F(DragnnBulkOpKernelsTest, BulkAdvanceFromPrediction) {
EXPECT_CALL(*mock_session,
AdvanceFromPrediction(kComponentName,
CheckScoresAreConsecutiveIntegersDivTen(),
kNumItems * kNumActions))
.Times(kNumSteps);
kNumItems, kNumActions))
.Times(kNumSteps)
.WillRepeatedly(Return(true));
// Run the kernel.
TF_EXPECT_OK(RunOpKernelWithContext());
}
// Verifies that the kernel surfaces a non-OK status when the session reports
// a failed AdvanceFromPrediction() partway through the unroll.
TEST_F(DragnnBulkOpKernelsTest, BulkAdvanceFromPredictionFailsIfAdvanceFails) {
  // Create and initialize the kernel under test.
  TF_ASSERT_OK(
      NodeDefBuilder("BulkAdvanceFromPrediction", "BulkAdvanceFromPrediction")
          .Attr("component", kComponentName)
          .Input(FakeInput(DT_STRING))  // The handle for the ComputeSession.
          .Input(FakeInput(DT_FLOAT))   // Prediction scores for advancing.
          .Finalize(node_def()));
  MockComputeSession *session = GetMockSession();

  // Build an input tensor in which each step sees a run of consecutive
  // integers divided by 10 as its scores.
  vector<float> scores(kNumItems * kNumSteps * kNumActions);
  int counter = 0;
  for (int step = 0; step < kNumSteps; ++step) {
    for (int item = 0; item < kNumItems; ++item) {
      const int row_offset = kNumActions * (step + item * kNumSteps);
      for (int action = 0; action < kNumActions; ++action) {
        scores[row_offset + action] = counter++ / 10.0f;
      }
    }
  }
  AddInputFromArray<float>(TensorShape({kNumItems * kNumSteps, kNumActions}),
                           scores);

  EXPECT_CALL(*session, BeamSize(kComponentName)).WillOnce(Return(1));
  EXPECT_CALL(*session, BatchSize(kComponentName))
      .WillOnce(Return(kNumItems));
  EXPECT_CALL(*session, IsTerminal(kComponentName))
      .Times(2)
      .WillRepeatedly(Return(false));

  // The second advance fails, which should fail the whole op.
  EXPECT_CALL(*session,
              AdvanceFromPrediction(kComponentName,
                                    CheckScoresAreConsecutiveIntegersDivTen(),
                                    kNumItems, kNumActions))
      .WillOnce(Return(true))
      .WillOnce(Return(false));

  // Run the kernel and expect an error status.
  const auto status = RunOpKernelWithContext();
  EXPECT_FALSE(status.ok());
}
} // namespace dragnn
} // namespace syntaxnet
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment