Commit d66941ac authored by Ivan Bogatyy's avatar Ivan Bogatyy Committed by calberti
Browse files

Sync w TF r0.12 & Bazel 0.4.3, internal updates (#953)

parent efa4a6cf
...@@ -73,8 +73,8 @@ class SentenceFeaturesTest : public ::testing::Test { ...@@ -73,8 +73,8 @@ class SentenceFeaturesTest : public ::testing::Test {
// Extracts a vector of string representations from evaluating the prepared // Extracts a vector of string representations from evaluating the prepared
// set feature (returning multiple values) at the given index. // set feature (returning multiple values) at the given index.
virtual vector<string> ExtractMultiFeature(int index) { virtual std::vector<string> ExtractMultiFeature(int index) {
vector<string> values; std::vector<string> values;
FeatureVector result; FeatureVector result;
extractor_->ExtractFeatures(workspaces_, sentence_, index, extractor_->ExtractFeatures(workspaces_, sentence_, index,
&result); &result);
...@@ -97,8 +97,8 @@ class SentenceFeaturesTest : public ::testing::Test { ...@@ -97,8 +97,8 @@ class SentenceFeaturesTest : public ::testing::Test {
// Checks that a vector workspace is equal to a target vector. // Checks that a vector workspace is equal to a target vector.
void CheckVectorWorkspace(const VectorIntWorkspace &workspace, void CheckVectorWorkspace(const VectorIntWorkspace &workspace,
vector<int> target) { std::vector<int> target) {
vector<int> src; std::vector<int> src;
for (int i = 0; i < workspace.size(); ++i) { for (int i = 0; i < workspace.size(); ++i) {
src.push_back(workspace.element(i)); src.push_back(workspace.element(i));
} }
......
...@@ -36,7 +36,8 @@ def AddCrossEntropy(batch_size, n): ...@@ -36,7 +36,8 @@ def AddCrossEntropy(batch_size, n):
return tf.constant(0, dtype=tf.float32, shape=[1]) return tf.constant(0, dtype=tf.float32, shape=[1])
for beam_id in range(batch_size): for beam_id in range(batch_size):
beam_gold_slot = tf.reshape(tf.slice(n['gold_slot'], [beam_id], [1]), [1]) beam_gold_slot = tf.reshape(
tf.strided_slice(n['gold_slot'], [beam_id], [beam_id + 1], [1]), [1])
def _ComputeCrossEntropy(): def _ComputeCrossEntropy():
"""Adds ops to compute cross entropy of the gold path in a beam.""" """Adds ops to compute cross entropy of the gold path in a beam."""
# Requires a cast so that UnsortedSegmentSum, in the gradient, # Requires a cast so that UnsortedSegmentSum, in the gradient,
...@@ -48,8 +49,9 @@ def AddCrossEntropy(batch_size, n): ...@@ -48,8 +49,9 @@ def AddCrossEntropy(batch_size, n):
beam_scores = tf.reshape(tf.gather(n['all_path_scores'], idx), [1, -1]) beam_scores = tf.reshape(tf.gather(n['all_path_scores'], idx), [1, -1])
num = tf.shape(idx) num = tf.shape(idx)
return tf.nn.softmax_cross_entropy_with_logits( return tf.nn.softmax_cross_entropy_with_logits(
beam_scores, tf.expand_dims( labels=tf.expand_dims(
tf.sparse_to_dense(beam_gold_slot, num, [1.], 0.), 0)) tf.sparse_to_dense(beam_gold_slot, num, [1.], 0.), 0),
logits=beam_scores)
# The conditional here is needed to deal with the last few batches of the # The conditional here is needed to deal with the last few batches of the
# corpus which can contain -1 in beam_gold_slot for empty batch slots. # corpus which can contain -1 in beam_gold_slot for empty batch slots.
cross_entropies.append(cf.cond( cross_entropies.append(cf.cond(
......
...@@ -128,10 +128,10 @@ class TaggerTransitionState : public ParserTransitionState { ...@@ -128,10 +128,10 @@ class TaggerTransitionState : public ParserTransitionState {
private: private:
// Currently assigned POS tags for each token in this sentence. // Currently assigned POS tags for each token in this sentence.
vector<int> tag_; std::vector<int> tag_;
// Gold POS tags from the input document. // Gold POS tags from the input document.
vector<int> gold_tag_; std::vector<int> gold_tag_;
// Tag map used for conversions between integer and string representations // Tag map used for conversions between integer and string representations
// part of speech tags. Not owned. // part of speech tags. Not owned.
......
...@@ -72,7 +72,7 @@ class TaskContext { ...@@ -72,7 +72,7 @@ class TaskContext {
// Vector of parameters required by this task. These must be specified in the // Vector of parameters required by this task. These must be specified in the
// task rather than relying on default values. // task rather than relying on default values.
vector<string> required_parameters_; std::vector<string> required_parameters_;
}; };
} // namespace syntaxnet } // namespace syntaxnet
......
...@@ -32,7 +32,7 @@ int TermFrequencyMap::Increment(const string &term) { ...@@ -32,7 +32,7 @@ int TermFrequencyMap::Increment(const string &term) {
const TermIndex::const_iterator it = term_index_.find(term); const TermIndex::const_iterator it = term_index_.find(term);
if (term_index_.find(term) != term_index_.end()) { if (term_index_.find(term) != term_index_.end()) {
// Increment the existing term. // Increment the existing term.
pair<string, int64> &data = term_data_[it->second]; std::pair<string, int64> &data = term_data_[it->second];
CHECK_EQ(term, data.first); CHECK_EQ(term, data.first);
++(data.second); ++(data.second);
return it->second; return it->second;
...@@ -41,7 +41,7 @@ int TermFrequencyMap::Increment(const string &term) { ...@@ -41,7 +41,7 @@ int TermFrequencyMap::Increment(const string &term) {
const int index = term_index_.size(); const int index = term_index_.size();
CHECK_LT(index, std::numeric_limits<int32>::max()); // overflow CHECK_LT(index, std::numeric_limits<int32>::max()); // overflow
term_index_[term] = index; term_index_[term] = index;
term_data_.push_back(pair<string, int64>(term, 1)); term_data_.push_back(std::pair<string, int64>(term, 1));
return index; return index;
} }
} }
...@@ -74,7 +74,7 @@ void TermFrequencyMap::Load(const string &filename, int min_frequency, ...@@ -74,7 +74,7 @@ void TermFrequencyMap::Load(const string &filename, int min_frequency,
int64 last_frequency = -1; int64 last_frequency = -1;
for (int i = 0; i < total && i < max_num_terms; ++i) { for (int i = 0; i < total && i < max_num_terms; ++i) {
TF_CHECK_OK(buffer.ReadLine(&line)); TF_CHECK_OK(buffer.ReadLine(&line));
vector<string> elements = utils::Split(line, ' '); std::vector<string> elements = utils::Split(line, ' ');
CHECK_EQ(2, elements.size()); CHECK_EQ(2, elements.size());
CHECK(!elements[0].empty()); CHECK(!elements[0].empty());
CHECK(!elements[1].empty()); CHECK(!elements[1].empty());
...@@ -97,7 +97,7 @@ void TermFrequencyMap::Load(const string &filename, int min_frequency, ...@@ -97,7 +97,7 @@ void TermFrequencyMap::Load(const string &filename, int min_frequency,
// Assign the next available index. // Assign the next available index.
const int index = term_index_.size(); const int index = term_index_.size();
term_index_[term] = index; term_index_[term] = index;
term_data_.push_back(pair<string, int64>(term, frequency)); term_data_.push_back(std::pair<string, int64>(term, frequency));
} }
CHECK_EQ(term_index_.size(), term_data_.size()); CHECK_EQ(term_index_.size(), term_data_.size());
LOG(INFO) << "Loaded " << term_index_.size() << " terms from " << filename LOG(INFO) << "Loaded " << term_index_.size() << " terms from " << filename
...@@ -107,8 +107,8 @@ void TermFrequencyMap::Load(const string &filename, int min_frequency, ...@@ -107,8 +107,8 @@ void TermFrequencyMap::Load(const string &filename, int min_frequency,
struct TermFrequencyMap::SortByFrequencyThenTerm { struct TermFrequencyMap::SortByFrequencyThenTerm {
// Return a > b to sort in descending order of frequency; otherwise, // Return a > b to sort in descending order of frequency; otherwise,
// lexicographic sort on term. // lexicographic sort on term.
bool operator()(const pair<string, int64> &a, bool operator()(const std::pair<string, int64> &a,
const pair<string, int64> &b) const { const std::pair<string, int64> &b) const {
return (a.second > b.second || (a.second == b.second && a.first < b.first)); return (a.second > b.second || (a.second == b.second && a.first < b.first));
} }
}; };
...@@ -117,7 +117,7 @@ void TermFrequencyMap::Save(const string &filename) const { ...@@ -117,7 +117,7 @@ void TermFrequencyMap::Save(const string &filename) const {
CHECK_EQ(term_index_.size(), term_data_.size()); CHECK_EQ(term_index_.size(), term_data_.size());
// Copy and sort the term data. // Copy and sort the term data.
vector<pair<string, int64>> sorted_data(term_data_); std::vector<std::pair<string, int64>> sorted_data(term_data_);
std::sort(sorted_data.begin(), sorted_data.end(), SortByFrequencyThenTerm()); std::sort(sorted_data.begin(), sorted_data.end(), SortByFrequencyThenTerm());
// Write the number of terms. // Write the number of terms.
...@@ -149,7 +149,7 @@ TagToCategoryMap::TagToCategoryMap(const string &filename) { ...@@ -149,7 +149,7 @@ TagToCategoryMap::TagToCategoryMap(const string &filename) {
tensorflow::io::BufferedInputStream buffer(&stream, kInputBufferSize); tensorflow::io::BufferedInputStream buffer(&stream, kInputBufferSize);
string line; string line;
while (buffer.ReadLine(&line) == tensorflow::Status::OK()) { while (buffer.ReadLine(&line) == tensorflow::Status::OK()) {
vector<string> pair = utils::Split(line, '\t'); std::vector<string> pair = utils::Split(line, '\t');
CHECK(line.empty() || pair.size() == 2) << line; CHECK(line.empty() || pair.size() == 2) << line;
tag_to_category_[pair[0]] = pair[1]; tag_to_category_[pair[0]] = pair[1];
} }
......
...@@ -83,7 +83,7 @@ class TermFrequencyMap { ...@@ -83,7 +83,7 @@ class TermFrequencyMap {
TermIndex term_index_; TermIndex term_index_;
// Mapping from indices to term and frequency. // Mapping from indices to term and frequency.
vector<pair<string, int64>> term_data_; std::vector<std::pair<string, int64>> term_data_;
TF_DISALLOW_COPY_AND_ASSIGN(TermFrequencyMap); TF_DISALLOW_COPY_AND_ASSIGN(TermFrequencyMap);
}; };
...@@ -107,7 +107,7 @@ class TagToCategoryMap { ...@@ -107,7 +107,7 @@ class TagToCategoryMap {
void Save(const string &filename) const; void Save(const string &filename) const;
private: private:
map<string, string> tag_to_category_; std::map<string, string> tag_to_category_;
TF_DISALLOW_COPY_AND_ASSIGN(TagToCategoryMap); TF_DISALLOW_COPY_AND_ASSIGN(TagToCategoryMap);
}; };
......
...@@ -83,16 +83,16 @@ class CoNLLSyntaxFormat : public DocumentFormat { ...@@ -83,16 +83,16 @@ class CoNLLSyntaxFormat : public DocumentFormat {
} }
void ConvertFromString(const string &key, const string &value, void ConvertFromString(const string &key, const string &value,
vector<Sentence *> *sentences) override { std::vector<Sentence *> *sentences) override {
// Create new sentence. // Create new sentence.
Sentence *sentence = new Sentence(); Sentence *sentence = new Sentence();
// Each line corresponds to one token. // Each line corresponds to one token.
string text; string text;
vector<string> lines = utils::Split(value, '\n'); std::vector<string> lines = utils::Split(value, '\n');
// Add each token to the sentence. // Add each token to the sentence.
vector<string> fields; std::vector<string> fields;
int expected_id = 1; int expected_id = 1;
for (size_t i = 0; i < lines.size(); ++i) { for (size_t i = 0; i < lines.size(); ++i) {
// Split line into tab-separated fields. // Split line into tab-separated fields.
...@@ -166,12 +166,12 @@ class CoNLLSyntaxFormat : public DocumentFormat { ...@@ -166,12 +166,12 @@ class CoNLLSyntaxFormat : public DocumentFormat {
void ConvertToString(const Sentence &sentence, string *key, void ConvertToString(const Sentence &sentence, string *key,
string *value) override { string *value) override {
*key = sentence.docid(); *key = sentence.docid();
vector<string> lines; std::vector<string> lines;
for (int i = 0; i < sentence.token_size(); ++i) { for (int i = 0; i < sentence.token_size(); ++i) {
Token token = sentence.token(i); Token token = sentence.token(i);
if (join_category_to_pos_) SplitCategoryFromPos(&token); if (join_category_to_pos_) SplitCategoryFromPos(&token);
if (add_pos_as_attribute_) RemovePosFromAttributes(&token); if (add_pos_as_attribute_) RemovePosFromAttributes(&token);
vector<string> fields(10); std::vector<string> fields(10);
fields[0] = tensorflow::strings::Printf("%d", i + 1); fields[0] = tensorflow::strings::Printf("%d", i + 1);
fields[1] = UnderscoreIfEmpty(token.word()); fields[1] = UnderscoreIfEmpty(token.word());
fields[2] = "_"; fields[2] = "_";
...@@ -198,14 +198,14 @@ class CoNLLSyntaxFormat : public DocumentFormat { ...@@ -198,14 +198,14 @@ class CoNLLSyntaxFormat : public DocumentFormat {
void AddMorphAttributes(const string &attributes, Token *token) { void AddMorphAttributes(const string &attributes, Token *token) {
TokenMorphology *morph = TokenMorphology *morph =
token->MutableExtension(TokenMorphology::morphology); token->MutableExtension(TokenMorphology::morphology);
vector<string> att_vals = utils::Split(attributes, '|'); std::vector<string> att_vals = utils::Split(attributes, '|');
for (int i = 0; i < att_vals.size(); ++i) { for (int i = 0; i < att_vals.size(); ++i) {
vector<string> att_val = utils::SplitOne(att_vals[i], '='); std::vector<string> att_val = utils::SplitOne(att_vals[i], '=');
// Format is either: // Format is either:
// 1) a1=v1|a2=v2..., e.g., Czech CoNLL data, or, // 1) a1=v1|a2=v2..., e.g., Czech CoNLL data, or,
// 2) v1|v2|..., e.g., German CoNLL data. // 2) v1|v2|..., e.g., German CoNLL data.
const pair<string, string> name_value = const std::pair<string, string> name_value =
att_val.size() == 2 ? std::make_pair(att_val[0], att_val[1]) att_val.size() == 2 ? std::make_pair(att_val[0], att_val[1])
: std::make_pair(att_val[0], "on"); : std::make_pair(att_val[0], "on");
...@@ -282,7 +282,7 @@ class CoNLLSyntaxFormat : public DocumentFormat { ...@@ -282,7 +282,7 @@ class CoNLLSyntaxFormat : public DocumentFormat {
TF_DISALLOW_COPY_AND_ASSIGN(CoNLLSyntaxFormat); TF_DISALLOW_COPY_AND_ASSIGN(CoNLLSyntaxFormat);
}; };
REGISTER_DOCUMENT_FORMAT("conll-sentence", CoNLLSyntaxFormat); REGISTER_SYNTAXNET_DOCUMENT_FORMAT("conll-sentence", CoNLLSyntaxFormat);
// Reader for segmentation training data format. This reader assumes the input // Reader for segmentation training data format. This reader assumes the input
// format is similar to CoNLL format but with only two fileds: // format is similar to CoNLL format but with only two fileds:
...@@ -325,16 +325,16 @@ class SegmentationTrainingDataFormat : public CoNLLSyntaxFormat { ...@@ -325,16 +325,16 @@ class SegmentationTrainingDataFormat : public CoNLLSyntaxFormat {
// to SPACE_BREAK to indicate that the corresponding gold transition for that // to SPACE_BREAK to indicate that the corresponding gold transition for that
// character token is START. Otherwise NO_BREAK to indicate MERGE. // character token is START. Otherwise NO_BREAK to indicate MERGE.
void ConvertFromString(const string &key, const string &value, void ConvertFromString(const string &key, const string &value,
vector<Sentence *> *sentences) override { std::vector<Sentence *> *sentences) override {
// Create new sentence. // Create new sentence.
Sentence *sentence = new Sentence(); Sentence *sentence = new Sentence();
// Each line corresponds to one token. // Each line corresponds to one token.
string text; string text;
vector<string> lines = utils::Split(value, '\n'); std::vector<string> lines = utils::Split(value, '\n');
// Add each token to the sentence. // Add each token to the sentence.
vector<string> fields; std::vector<string> fields;
for (size_t i = 0; i < lines.size(); ++i) { for (size_t i = 0; i < lines.size(); ++i) {
// Split line into tab-separated fields. // Split line into tab-separated fields.
fields.clear(); fields.clear();
...@@ -362,7 +362,7 @@ class SegmentationTrainingDataFormat : public CoNLLSyntaxFormat { ...@@ -362,7 +362,7 @@ class SegmentationTrainingDataFormat : public CoNLLSyntaxFormat {
} }
// Add character-based token to sentence. // Add character-based token to sentence.
vector<tensorflow::StringPiece> chars; std::vector<tensorflow::StringPiece> chars;
SegmenterUtils::GetUTF8Chars(word, &chars); SegmenterUtils::GetUTF8Chars(word, &chars);
bool is_first_char = true; bool is_first_char = true;
for (auto utf8char : chars) { for (auto utf8char : chars) {
...@@ -398,7 +398,8 @@ class SegmentationTrainingDataFormat : public CoNLLSyntaxFormat { ...@@ -398,7 +398,8 @@ class SegmentationTrainingDataFormat : public CoNLLSyntaxFormat {
} }
}; };
REGISTER_DOCUMENT_FORMAT("segment-train-data", SegmentationTrainingDataFormat); REGISTER_SYNTAXNET_DOCUMENT_FORMAT("segment-train-data",
SegmentationTrainingDataFormat);
// Reader for tokenized text. This reader expects every sentence to be on a // Reader for tokenized text. This reader expects every sentence to be on a
// single line and tokens on that line to be separated by single spaces. // single line and tokens on that line to be separated by single spaces.
...@@ -414,7 +415,7 @@ class TokenizedTextFormat : public DocumentFormat { ...@@ -414,7 +415,7 @@ class TokenizedTextFormat : public DocumentFormat {
} }
void ConvertFromString(const string &key, const string &value, void ConvertFromString(const string &key, const string &value,
vector<Sentence *> *sentences) override { std::vector<Sentence *> *sentences) override {
Sentence *sentence = new Sentence(); Sentence *sentence = new Sentence();
string text; string text;
for (const string &word : utils::Split(value, ' ')) { for (const string &word : utils::Split(value, ' ')) {
...@@ -463,7 +464,7 @@ class TokenizedTextFormat : public DocumentFormat { ...@@ -463,7 +464,7 @@ class TokenizedTextFormat : public DocumentFormat {
TF_DISALLOW_COPY_AND_ASSIGN(TokenizedTextFormat); TF_DISALLOW_COPY_AND_ASSIGN(TokenizedTextFormat);
}; };
REGISTER_DOCUMENT_FORMAT("tokenized-text", TokenizedTextFormat); REGISTER_SYNTAXNET_DOCUMENT_FORMAT("tokenized-text", TokenizedTextFormat);
// Reader for un-tokenized text. This reader expects every sentence to be on a // Reader for un-tokenized text. This reader expects every sentence to be on a
// single line. For each line in the input, a sentence proto will be created, // single line. For each line in the input, a sentence proto will be created,
...@@ -474,9 +475,9 @@ class UntokenizedTextFormat : public TokenizedTextFormat { ...@@ -474,9 +475,9 @@ class UntokenizedTextFormat : public TokenizedTextFormat {
UntokenizedTextFormat() {} UntokenizedTextFormat() {}
void ConvertFromString(const string &key, const string &value, void ConvertFromString(const string &key, const string &value,
vector<Sentence *> *sentences) override { std::vector<Sentence *> *sentences) override {
Sentence *sentence = new Sentence(); Sentence *sentence = new Sentence();
vector<tensorflow::StringPiece> chars; std::vector<tensorflow::StringPiece> chars;
SegmenterUtils::GetUTF8Chars(value, &chars); SegmenterUtils::GetUTF8Chars(value, &chars);
int start = 0; int start = 0;
for (auto utf8char : chars) { for (auto utf8char : chars) {
...@@ -502,7 +503,7 @@ class UntokenizedTextFormat : public TokenizedTextFormat { ...@@ -502,7 +503,7 @@ class UntokenizedTextFormat : public TokenizedTextFormat {
TF_DISALLOW_COPY_AND_ASSIGN(UntokenizedTextFormat); TF_DISALLOW_COPY_AND_ASSIGN(UntokenizedTextFormat);
}; };
REGISTER_DOCUMENT_FORMAT("untokenized-text", UntokenizedTextFormat); REGISTER_SYNTAXNET_DOCUMENT_FORMAT("untokenized-text", UntokenizedTextFormat);
// Text reader that attmpts to perform Penn Treebank tokenization on arbitrary // Text reader that attmpts to perform Penn Treebank tokenization on arbitrary
// raw text. Adapted from https://www.cis.upenn.edu/~treebank/tokenizer.sed // raw text. Adapted from https://www.cis.upenn.edu/~treebank/tokenizer.sed
...@@ -514,8 +515,8 @@ class EnglishTextFormat : public TokenizedTextFormat { ...@@ -514,8 +515,8 @@ class EnglishTextFormat : public TokenizedTextFormat {
EnglishTextFormat() {} EnglishTextFormat() {}
void ConvertFromString(const string &key, const string &value, void ConvertFromString(const string &key, const string &value,
vector<Sentence *> *sentences) override { std::vector<Sentence *> *sentences) override {
vector<pair<string, string>> preproc_rules = { std::vector<std::pair<string, string>> preproc_rules = {
// Punctuation. // Punctuation.
{"’", "'"}, {"’", "'"},
{"…", "..."}, {"…", "..."},
...@@ -570,7 +571,7 @@ class EnglishTextFormat : public TokenizedTextFormat { ...@@ -570,7 +571,7 @@ class EnglishTextFormat : public TokenizedTextFormat {
{"♦", ""}, {"♦", ""},
}; };
vector<pair<string, string>> rules = { std::vector<std::pair<string, string>> rules = {
// attempt to get correct directional quotes // attempt to get correct directional quotes
{R"re(^")re", "`` "}, {R"re(^")re", "`` "},
{R"re(([ \([{<])")re", "\\1 `` "}, {R"re(([ \([{<])")re", "\\1 `` "},
...@@ -639,10 +640,10 @@ class EnglishTextFormat : public TokenizedTextFormat { ...@@ -639,10 +640,10 @@ class EnglishTextFormat : public TokenizedTextFormat {
}; };
string rewritten = value; string rewritten = value;
for (const pair<string, string> &rule : preproc_rules) { for (const std::pair<string, string> &rule : preproc_rules) {
RE2::GlobalReplace(&rewritten, rule.first, rule.second); RE2::GlobalReplace(&rewritten, rule.first, rule.second);
} }
for (const pair<string, string> &rule : rules) { for (const std::pair<string, string> &rule : rules) {
RE2::GlobalReplace(&rewritten, rule.first, rule.second); RE2::GlobalReplace(&rewritten, rule.first, rule.second);
} }
TokenizedTextFormat::ConvertFromString(key, rewritten, sentences); TokenizedTextFormat::ConvertFromString(key, rewritten, sentences);
...@@ -652,6 +653,6 @@ class EnglishTextFormat : public TokenizedTextFormat { ...@@ -652,6 +653,6 @@ class EnglishTextFormat : public TokenizedTextFormat {
TF_DISALLOW_COPY_AND_ASSIGN(EnglishTextFormat); TF_DISALLOW_COPY_AND_ASSIGN(EnglishTextFormat);
}; };
REGISTER_DOCUMENT_FORMAT("english-text", EnglishTextFormat); REGISTER_SYNTAXNET_DOCUMENT_FORMAT("english-text", EnglishTextFormat);
} // namespace syntaxnet } // namespace syntaxnet
...@@ -37,7 +37,7 @@ VectorIntWorkspace::VectorIntWorkspace(int size) : elements_(size) {} ...@@ -37,7 +37,7 @@ VectorIntWorkspace::VectorIntWorkspace(int size) : elements_(size) {}
VectorIntWorkspace::VectorIntWorkspace(int size, int value) VectorIntWorkspace::VectorIntWorkspace(int size, int value)
: elements_(size, value) {} : elements_(size, value) {}
VectorIntWorkspace::VectorIntWorkspace(const vector<int> &elements) VectorIntWorkspace::VectorIntWorkspace(const std::vector<int> &elements)
: elements_(elements) {} : elements_(elements) {}
string VectorIntWorkspace::TypeName() { return "Vector"; } string VectorIntWorkspace::TypeName() { return "Vector"; }
......
...@@ -57,7 +57,7 @@ class WorkspaceRegistry { ...@@ -57,7 +57,7 @@ class WorkspaceRegistry {
int Request(const string &name) { int Request(const string &name) {
const std::type_index id = std::type_index(typeid(W)); const std::type_index id = std::type_index(typeid(W));
workspace_types_[id] = W::TypeName(); workspace_types_[id] = W::TypeName();
vector<string> &names = workspace_names_[id]; std::vector<string> &names = workspace_names_[id];
for (int i = 0; i < names.size(); ++i) { for (int i = 0; i < names.size(); ++i) {
if (names[i] == name) return i; if (names[i] == name) return i;
} }
...@@ -65,8 +65,8 @@ class WorkspaceRegistry { ...@@ -65,8 +65,8 @@ class WorkspaceRegistry {
return names.size() - 1; return names.size() - 1;
} }
const std::unordered_map<std::type_index, vector<string> > &WorkspaceNames() const std::unordered_map<std::type_index, std::vector<string> >
const { &WorkspaceNames() const {
return workspace_names_; return workspace_names_;
} }
...@@ -78,7 +78,7 @@ class WorkspaceRegistry { ...@@ -78,7 +78,7 @@ class WorkspaceRegistry {
std::unordered_map<std::type_index, string> workspace_types_; std::unordered_map<std::type_index, string> workspace_types_;
// Workspace names, indexed as workspace_names_[typeid][workspace]. // Workspace names, indexed as workspace_names_[typeid][workspace].
std::unordered_map<std::type_index, vector<string> > workspace_names_; std::unordered_map<std::type_index, std::vector<string> > workspace_names_;
TF_DISALLOW_COPY_AND_ASSIGN(WorkspaceRegistry); TF_DISALLOW_COPY_AND_ASSIGN(WorkspaceRegistry);
}; };
...@@ -137,7 +137,7 @@ class WorkspaceSet { ...@@ -137,7 +137,7 @@ class WorkspaceSet {
private: private:
// The set of workspaces, indexed as workspaces_[typeid][index]. // The set of workspaces, indexed as workspaces_[typeid][index].
std::unordered_map<std::type_index, vector<Workspace *> > workspaces_; std::unordered_map<std::type_index, std::vector<Workspace *> > workspaces_;
}; };
// A workspace that wraps around a single int. // A workspace that wraps around a single int.
...@@ -170,7 +170,7 @@ class VectorIntWorkspace : public Workspace { ...@@ -170,7 +170,7 @@ class VectorIntWorkspace : public Workspace {
explicit VectorIntWorkspace(int size); explicit VectorIntWorkspace(int size);
// Creates a vector initialized with the given array. // Creates a vector initialized with the given array.
explicit VectorIntWorkspace(const vector<int> &elements); explicit VectorIntWorkspace(const std::vector<int> &elements);
// Creates a vector of the given size, with each element initialized to the // Creates a vector of the given size, with each element initialized to the
// given value. // given value.
...@@ -189,7 +189,7 @@ class VectorIntWorkspace : public Workspace { ...@@ -189,7 +189,7 @@ class VectorIntWorkspace : public Workspace {
private: private:
// The enclosed vector. // The enclosed vector.
vector<int> elements_; std::vector<int> elements_;
}; };
// A workspace that wraps around a vector of vector of int. // A workspace that wraps around a vector of vector of int.
...@@ -202,14 +202,14 @@ class VectorVectorIntWorkspace : public Workspace { ...@@ -202,14 +202,14 @@ class VectorVectorIntWorkspace : public Workspace {
static string TypeName(); static string TypeName();
// Returns the i'th vector of elements. // Returns the i'th vector of elements.
const vector<int> &elements(int i) const { return elements_[i]; } const std::vector<int> &elements(int i) const { return elements_[i]; }
// Mutable access to the i'th vector of elements. // Mutable access to the i'th vector of elements.
vector<int> *mutable_elements(int i) { return &(elements_[i]); } std::vector<int> *mutable_elements(int i) { return &(elements_[i]); }
private: private:
// The enclosed vector of vector of elements. // The enclosed vector of vector of elements.
vector<vector<int> > elements_; std::vector<std::vector<int> > elements_;
}; };
} // namespace syntaxnet } // namespace syntaxnet
......
Subproject commit aab099711d7e04034cf742ddb9b00dd15edbe99c Subproject commit 45ab528211c962b19e12f6b77165848310271624
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment