sentence_features.h 21.6 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
/* Copyright 2016 Google Inc. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// Features that operate on Sentence objects. Most features are defined
// in this header so they may be re-used via composition into other more
// advanced feature classes.

calberti's avatar
calberti committed
20
21
#ifndef SYNTAXNET_SENTENCE_FEATURES_H_
#define SYNTAXNET_SENTENCE_FEATURES_H_
22
23
24
25

#include "syntaxnet/affix.h"
#include "syntaxnet/feature_extractor.h"
#include "syntaxnet/feature_types.h"
26
#include "syntaxnet/segmenter_utils.h"
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
#include "syntaxnet/shared_store.h"
#include "syntaxnet/task_context.h"
#include "syntaxnet/workspace.h"

namespace syntaxnet {

// Base type for features computed over a Sentence, where the focus is the
// index of a token within that sentence.
using SentenceFeature = FeatureFunction<Sentence, int>;

// Shorthand for locator-style features: they take a (Sentence, int) focus and
// call other features that use the same (Sentence, int) signature.
template <typename Derived>
using Locator = FeatureLocator<Derived, Sentence, int>;

class TokenLookupFeature : public SentenceFeature {
 public:
  void Init(TaskContext *context) override {
    set_feature_type(new ResourceBasedFeatureType<TokenLookupFeature>(
        name(), this, {{NumValues(), "<OUTSIDE>"}}));
  }

  // Given a position in a sentence and workspaces, looks up the corresponding
  // feature value. The index is relative to the start of the sentence.
  virtual FeatureValue ComputeValue(const Token &token) const = 0;

  // Number of unique values.
  virtual int64 NumValues() const = 0;

  // Convert the numeric value of the feature to a human readable string.
  virtual string GetFeatureValueName(FeatureValue value) const = 0;

  // Name of the shared workspace.
  virtual string WorkspaceName() const = 0;

  // Runs ComputeValue for each token in the sentence.
  void Preprocess(WorkspaceSet *workspaces,
                  Sentence *sentence) const override {
    if (workspaces->Has<VectorIntWorkspace>(workspace_)) return;
    VectorIntWorkspace *workspace = new VectorIntWorkspace(
        sentence->token_size());
    for (int i = 0; i < sentence->token_size(); ++i) {
      const int value = ComputeValue(sentence->token(i));
      workspace->set_element(i, value);
    }
    workspaces->Set<VectorIntWorkspace>(workspace_, workspace);
  }

  // Requests a vector of int's to store in the workspace registry.
  void RequestWorkspaces(WorkspaceRegistry *registry) override {
    workspace_ = registry->Request<VectorIntWorkspace>(WorkspaceName());
  }

  // Returns the precomputed value, or NumValues() for features outside
  // the sentence.
  FeatureValue Compute(const WorkspaceSet &workspaces,
                       const Sentence &sentence, int focus,
                       const FeatureVector *result) const override {
    if (focus < 0 || focus >= sentence.token_size()) return NumValues();
    return workspaces.Get<VectorIntWorkspace>(workspace_).element(focus);
  }

89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
  int Workspace() const { return workspace_; }

 private:
  int workspace_;
};

// A multi purpose specialization of the feature. Processes the tokens in a
// Sentence by looking up a value set for each token and storing that in
// a VectorVectorInt workspace. Given a set of base values of size Size(),
// reserves an extra value for unknown tokens.
class TokenLookupSetFeature : public SentenceFeature {
 public:
  void Init(TaskContext *context) override {
    set_feature_type(new ResourceBasedFeatureType<TokenLookupSetFeature>(
        name(), this, {{NumValues(), "<OUTSIDE>"}}));
  }

  // Number of unique values.
  virtual int64 NumValues() const = 0;

  // Given a position in a sentence and workspaces, looks up the corresponding
  // feature value set. The index is relative to the start of the sentence.
  virtual void LookupToken(const WorkspaceSet &workspaces,
                           const Sentence &sentence, int index,
113
                           std::vector<int> *values) const = 0;
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139

  // Given a feature value, returns a string representation.
  virtual string GetFeatureValueName(int value) const = 0;

  // Name of the shared workspace.
  virtual string WorkspaceName() const = 0;

  // TokenLookupSetFeatures use VectorVectorIntWorkspaces by default.
  void RequestWorkspaces(WorkspaceRegistry *registry) override {
    workspace_ = registry->Request<VectorVectorIntWorkspace>(WorkspaceName());
  }

  // Default preprocessing: looks up a value set for each token in the Sentence.
  void Preprocess(WorkspaceSet *workspaces, Sentence *sentence) const override {
    // Default preprocessing: lookup a value set for each token in the Sentence.
    if (workspaces->Has<VectorVectorIntWorkspace>(workspace_)) return;
    VectorVectorIntWorkspace *workspace =
        new VectorVectorIntWorkspace(sentence->token_size());
    for (int i = 0; i < sentence->token_size(); ++i) {
      LookupToken(*workspaces, *sentence, i, workspace->mutable_elements(i));
    }
    workspaces->Set<VectorVectorIntWorkspace>(workspace_, workspace);
  }

  // Returns a pre-computed token value from the cache. This assumes the cache
  // is populated.
140
  const std::vector<int> &GetCachedValueSet(const WorkspaceSet &workspaces,
141
142
143
144
145
146
147
148
149
150
151
152
153
154
                                       const Sentence &sentence,
                                       int focus) const {
    // Do bounds checking on focus.
    CHECK_GE(focus, 0);
    CHECK_LT(focus, sentence.token_size());

    // Return value from cache.
    return workspaces.Get<VectorVectorIntWorkspace>(workspace_).elements(focus);
  }

  // Adds any precomputed features at the given focus, if present.
  void Evaluate(const WorkspaceSet &workspaces, const Sentence &sentence,
                int focus, FeatureVector *result) const override {
    if (focus >= 0 && focus < sentence.token_size()) {
155
      const std::vector<int> &elements =
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
          GetCachedValueSet(workspaces, sentence, focus);
      for (auto &value : elements) {
        result->add(this->feature_type(), value);
      }
    }
  }

  // Returns the precomputed value, or NumValues() for features outside
  // the sentence.
  FeatureValue Compute(const WorkspaceSet &workspaces, const Sentence &sentence,
                       int focus, const FeatureVector *result) const override {
    if (focus < 0 || focus >= sentence.token_size()) return NumValues();
    return workspaces.Get<VectorIntWorkspace>(workspace_).element(focus);
  }

171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
 private:
  int workspace_;
};

// Lookup feature that uses a TermFrequencyMap to store a string->int mapping.
class TermFrequencyMapFeature : public TokenLookupFeature {
 public:
  explicit TermFrequencyMapFeature(const string &input_name)
      : input_name_(input_name), min_freq_(0), max_num_terms_(0) {}
  ~TermFrequencyMapFeature() override;

  // Requests the input map as a resource.
  void Setup(TaskContext *context) override;

  // Loads the input map into memory (using SharedStore to avoid redundancy.)
  void Init(TaskContext *context) override;

  // Number of unique values.
calberti's avatar
calberti committed
189
  int64 NumValues() const override { return term_map_->Size() + 1; }
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219

  // Special value for strings not in the map.
  FeatureValue UnknownValue() const { return term_map_->Size(); }

  // Uses the TermFrequencyMap to lookup the string associated with a value.
  string GetFeatureValueName(FeatureValue value) const override;

  // Name of the shared workspace.
  string WorkspaceName() const override;

 protected:
  const TermFrequencyMap &term_map() const { return *term_map_; }

 private:
  // Shortcut pointer to shared map. Not owned.
  const TermFrequencyMap *term_map_ = nullptr;

  // Name of the input for the term map.
  string input_name_;

  // Filename of the underlying resource.
  string file_name_;

  // Minimum frequency for term map.
  int min_freq_;

  // Maximum number of terms for term map.
  int max_num_terms_;
};

220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
// Specialization of the TokenLookupSetFeature class to use a TermFrequencyMap
// to perform the mapping. This takes two options: "min_freq" (discard tokens
// with less than this min frequency), and "max_num_terms" (only read in at most
// these terms.)
class TermFrequencyMapSetFeature : public TokenLookupSetFeature {
 public:
  // Initializes with an empty name, since we need the options to compute the
  // actual workspace name.
  explicit TermFrequencyMapSetFeature(const string &input_name)
      : input_name_(input_name), min_freq_(0), max_num_terms_(0) {}

  // Releases shared resources.
  ~TermFrequencyMapSetFeature() override;

  // Returns index of raw word text.
  virtual void GetTokenIndices(const Token &token,
236
                               std::vector<int> *values) const = 0;
237
238
239
240
241
242
243
244
245

  // Requests the resource inputs.
  void Setup(TaskContext *context) override;

  // Obtains resources using the shared store. At this point options are known
  // so the full name can be computed.
  void Init(TaskContext *context) override;

  // Number of unique values.
246

247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
  int64 NumValues() const override { return term_map_->Size(); }

  // Special value for strings not in the map.
  FeatureValue UnknownValue() const { return term_map_->Size(); }

  // Gets pointer to the underlying map.
  const TermFrequencyMap *term_map() const { return term_map_; }

  // Returns the term index or the unknown value. Used inside GetTokenIndex()
  // specializations for convenience.
  int LookupIndex(const string &term) const {
    return term_map_->LookupIndex(term, -1);
  }

  // Given a position in a sentence and workspaces, looks up the corresponding
  // feature value set. The index is relative to the start of the sentence.
  void LookupToken(const WorkspaceSet &workspaces, const Sentence &sentence,
264
                   int index, std::vector<int> *values) const override {
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
    GetTokenIndices(sentence.token(index), values);
  }

  // Uses the TermFrequencyMap to lookup the string associated with a value.
  string GetFeatureValueName(int value) const override {
    if (value == UnknownValue()) return "<UNKNOWN>";
    if (value >= 0 && value < NumValues()) {
      return term_map_->GetTerm(value);
    }
    LOG(ERROR) << "Invalid feature value: " << value;
    return "<INVALID>";
  }

  // Name of the shared workspace.
  string WorkspaceName() const override;

 private:
  // Shortcut pointer to shared map. Not owned.
  const TermFrequencyMap *term_map_ = nullptr;

  // Name of the input for the term map.
  string input_name_;

  // Filename of the underlying resource.
  string file_name_;

  // Minimum frequency for term map.
  int min_freq_;

  // Maximum number of terms for term map.
  int max_num_terms_;
};

298
299
300
301
302
// Feature that looks up the word of a token in the "word-map" term map.
class Word : public TermFrequencyMapFeature {
 public:
  Word() : TermFrequencyMapFeature("word-map") {}

  // Returns the index of the token's word in the term map, or UnknownValue()
  // if the word is not in the map.
  FeatureValue ComputeValue(const Token &token) const override {
    const string &form = token.word();
    return term_map().LookupIndex(form, UnknownValue());
  }
};

308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
// Feature over the "char-map" term map. Besides the terms of the map it has
// two extra values: one for break characters and one for unknown strings.
class Char : public TermFrequencyMapFeature {
 public:
  Char() : TermFrequencyMapFeature("char-map") {}

  // Maps break characters to BreakCharValue(); everything else is looked up in
  // the term map, falling back to UnknownValue().
  FeatureValue ComputeValue(const Token &token) const override {
    const string &text = token.word();
    if (!SegmenterUtils::IsBreakChar(text)) {
      return term_map().LookupIndex(text, UnknownValue());
    }
    return BreakCharValue();
  }

  // Value reserved for break characters: the first id after the map's terms.
  FeatureValue BreakCharValue() const { return term_map().Size(); }

  // Value reserved for non-break strings that are missing from the map.
  FeatureValue UnknownValue() const { return term_map().Size() + 1; }

  // Total number of values: the map's terms plus the two reserved values.
  int64 NumValues() const override { return term_map().Size() + 2; }

  // Returns a readable name for the value. Logs an error and returns
  // "<INVALID>" for values that are neither map terms nor reserved values.
  string GetFeatureValueName(FeatureValue value) const override {
    const bool is_map_term = value >= 0 && value < term_map().Size();
    if (is_map_term) return term_map().GetTerm(value);
    if (value == BreakCharValue()) return "<BREAK_CHAR>";
    if (value == UnknownValue()) return "<UNKNOWN>";
    LOG(ERROR) << "Invalid feature value: " << value;
    return "<INVALID>";
  }
};

338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
// Feature that looks up the lowercased word of a token in the "lc-word-map"
// term map.
class LowercaseWord : public TermFrequencyMapFeature {
 public:
  LowercaseWord() : TermFrequencyMapFeature("lc-word-map") {}

  // Lowercases the token's word and returns its index in the term map, or
  // UnknownValue() if it is not in the map.
  FeatureValue ComputeValue(const Token &token) const override {
    return term_map().LookupIndex(utils::Lowercase(token.word()),
                                  UnknownValue());
  }
};

// Feature that looks up the tag of a token in the "tag-map" term map.
class Tag : public TermFrequencyMapFeature {
 public:
  Tag() : TermFrequencyMapFeature("tag-map") {}

  // Returns the index of the token's tag in the term map, or UnknownValue()
  // if the tag is not in the map.
  FeatureValue ComputeValue(const Token &token) const override {
    const auto &tag = token.tag();
    return term_map().LookupIndex(tag, UnknownValue());
  }
};

// Feature that looks up the label of a token in the "label-map" term map.
class Label : public TermFrequencyMapFeature {
 public:
  Label() : TermFrequencyMapFeature("label-map") {}

  // Returns the index of the token's label in the term map, or UnknownValue()
  // if the label is not in the map.
  FeatureValue ComputeValue(const Token &token) const override {
    const auto &label = token.label();
    return term_map().LookupIndex(label, UnknownValue());
  }
};

366
367
368
369
370
371
372
373
374
375
376
377
378
class CharNgram : public TermFrequencyMapSetFeature {
 public:
  CharNgram() : TermFrequencyMapSetFeature("char-ngram-map") {}
  ~CharNgram() override {}

  void Setup(TaskContext *context) override {
    TermFrequencyMapSetFeature::Setup(context);
    max_char_ngram_length_ = context->Get("lexicon_max_char_ngram_length", 3);
    use_terminators_ =
        context->Get("lexicon_char_ngram_include_terminators", false);
  }

  // Returns index of raw word text.
379
380
  void GetTokenIndices(const Token &token,
                       std::vector<int> *values) const override;
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404

 private:
  // Size parameter (n) for the ngrams.
  int max_char_ngram_length_ = 3;

  // Whether to pad the word with ^ and $ before extracting ngrams.
  bool use_terminators_ = false;
};

class MorphologySet : public TermFrequencyMapSetFeature {
 public:
  MorphologySet() : TermFrequencyMapSetFeature("morphology-map") {}
  ~MorphologySet() override {}

  void Setup(TaskContext *context) override {
    TermFrequencyMapSetFeature::Setup(context);
  }


  int64 NumValues() const override {
    return term_map()->Size() - 1;
  }

  // Returns index of raw word text.
405
406
  void GetTokenIndices(const Token &token,
                       std::vector<int> *values) const override;
407
408
};

409
410
411
412
413
414
415
416
// Base class for token features whose values come from a small fixed set of
// categories. Subclasses supply the category name and the number of values.
class LexicalCategoryFeature : public TokenLookupFeature {
 public:
  // |name| identifies the category type and |cardinality| is the number of
  // distinct values; both are used to build the workspace name.
  LexicalCategoryFeature(const string &name, int cardinality)
      : name_(name), cardinality_(cardinality) {}
  ~LexicalCategoryFeature() override {}

  // Number of unique values, as given to the constructor.
  FeatureValue NumValues() const override { return cardinality_; }

  // Returns the identifier for the workspace for this feature, in the form
  // "<name>:<cardinality>".
  string WorkspaceName() const override {
    return tensorflow::strings::StrCat(name_, ":", cardinality_);
  }

 private:
  // Name of the category type.
  const string name_;

  // Number of values.
  const int cardinality_;
};

430
// Feature that computes whether a word has a hyphen or not.
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
class Hyphen : public LexicalCategoryFeature {
 public:
  // Enumeration of values.
  enum Category {
    NO_HYPHEN = 0,
    HAS_HYPHEN = 1,
    CARDINALITY = 2,
  };

  // Default constructor.
  Hyphen() : LexicalCategoryFeature("hyphen", CARDINALITY) {}

  // Returns a string representation of the enum value.
  string GetFeatureValueName(FeatureValue value) const override;

  // Returns the category value for the token.
  FeatureValue ComputeValue(const Token &token) const override;
};

450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
// Feature that categorizes the capitalization of the word. If the option
// utf8=true is specified, lowercase and uppercase checks are done with UTF8
// compliant functions.
class Capitalization : public LexicalCategoryFeature {
 public:
  // Enumeration of values.
  enum Category {
    LOWERCASE = 0,                     // normal word
    UPPERCASE = 1,                     // all-caps
    CAPITALIZED = 2,                   // has one cap and one non-cap
    CAPITALIZED_SENTENCE_INITIAL = 3,  // same as above but sentence-initial
    NON_ALPHABETIC = 4,                // contains no alphabetic characters
    CARDINALITY = 5,
  };

  // Default constructor. Registers the category as "capitalization" with
  // CARDINALITY values.
  Capitalization() : LexicalCategoryFeature("capitalization", CARDINALITY) {}

  // Sets one of the options for the capitalization.
  void Setup(TaskContext *context) override;

  // Capitalization needs special preprocessing because token category can
  // depend on whether the token is at the start of the sentence.
  void Preprocess(WorkspaceSet *workspaces, Sentence *sentence) const override;

  // Returns a string representation of the enum value.
  string GetFeatureValueName(FeatureValue value) const override;

  // Not supported for this feature: always logs a fatal error that points
  // callers to ComputeValueWithFocus(). The return statement only satisfies
  // the signature.
  FeatureValue ComputeValue(const Token &token) const override {
    LOG(FATAL) << "Capitalization should use ComputeValueWithFocus.";
    return 0;
  }

  // Returns the category value for the token. |focus| is presumably the
  // token's index in the sentence -- confirm against the implementation.
  FeatureValue ComputeValueWithFocus(const Token &token, int focus) const;

 private:
  // Whether to use UTF8 compliant functions to check capitalization.
  bool utf8_ = false;
};

// A feature that categorizes how much of the focus token is punctuation:
// none, some, or all of it (see Category).
class PunctuationAmount : public LexicalCategoryFeature {
 public:
  // Enumeration of values.
  enum Category {
    NO_PUNCTUATION = 0,
    SOME_PUNCTUATION = 1,
    ALL_PUNCTUATION = 2,
    CARDINALITY = 3,
  };

  // Default constructor. Registers the category as "punctuation-amount" with
  // CARDINALITY values.
  PunctuationAmount()
      : LexicalCategoryFeature("punctuation-amount", CARDINALITY) {}

  // Returns a string representation of the enum value.
  string GetFeatureValueName(FeatureValue value) const override;

  // Returns the category value for the token.
  FeatureValue ComputeValue(const Token &token) const override;
};

// A feature that returns whether the word is an open or close quotation mark,
// based on its relative position to other quotation marks in the sentence.
class Quote : public LexicalCategoryFeature {
 public:
  // Enumeration of values.
  enum Category {
    NO_QUOTE = 0,
    OPEN_QUOTE = 1,
    CLOSE_QUOTE = 2,
    UNKNOWN_QUOTE = 3,
    CARDINALITY = 4,
  };

  // Default constructor. Registers the category as "quote" with CARDINALITY
  // values.
  Quote() : LexicalCategoryFeature("quote", CARDINALITY) {}

  // Returns a string representation of the enum value.
  string GetFeatureValueName(FeatureValue value) const override;

  // Returns the category value for the token.
  FeatureValue ComputeValue(const Token &token) const override;

  // Override preprocess to compute open and close quotes from prior context of
  // the sentence.
  void Preprocess(WorkspaceSet *workspaces, Sentence *instance) const override;
};

// Feature that computes whether a word has digits or not.
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
class Digit : public LexicalCategoryFeature {
 public:
  // Enumeration of values.
  enum Category {
    NO_DIGIT = 0,
    SOME_DIGIT = 1,
    ALL_DIGIT = 2,
    CARDINALITY = 3,
  };

  // Default constructor.
  Digit() : LexicalCategoryFeature("digit", CARDINALITY) {}

  // Returns a string representation of the enum value.
  string GetFeatureValueName(FeatureValue value) const override;

  // Returns the category value for the token.
  FeatureValue ComputeValue(const Token &token) const override;
};

564
// TokenLookupFeature object to compute prefixes and suffixes of words. The
565
// AffixTable is stored in the SharedStore. This is very similar to the
566
// implementation of TermFrequencyMapFeature, but using an AffixTable to
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
// perform the lookups. There are only two specializations, for prefixes and
// suffixes.
class AffixTableFeature : public TokenLookupFeature {
 public:
  // Explicit constructor to set the type of the table. This determines the
  // requested input.
  explicit AffixTableFeature(AffixTable::Type type);
  ~AffixTableFeature() override;

  // Requests inputs for the affix table.
  void Setup(TaskContext *context) override;

  // Loads the affix table from the SharedStore.
  void Init(TaskContext *context) override;

  // The workspace name is specific to which affix length we are computing.
  string WorkspaceName() const override;

  // Returns the total number of affixes in the table, regardless of specified
  // length.
  FeatureValue NumValues() const override { return affix_table_->size() + 1; }

  // Special value for strings not in the map.
  FeatureValue UnknownValue() const { return affix_table_->size(); }

  // Looks up the affix for a given word.
  FeatureValue ComputeValue(const Token &token) const override;

  // Returns the string associated with a value.
  string GetFeatureValueName(FeatureValue value) const override;

 private:
  // Size parameter for the affix table.
  int affix_length_;

  // Name of the input for the table.
  string input_name_;

  // The type of the affix table.
  const AffixTable::Type type_;

  // Affix table used for indexing. This comes from the shared store, and is not
  // owned directly.
  const AffixTable *affix_table_ = nullptr;
};

// Specific instantiation for computing prefixes: an AffixTableFeature
// constructed with AffixTable::PREFIX. This requires the input
// "prefix-table".
class PrefixFeature : public AffixTableFeature {
 public:
  PrefixFeature() : AffixTableFeature(AffixTable::PREFIX) {}
};

// Specific instantiation for computing suffixes: an AffixTableFeature
// constructed with AffixTable::SUFFIX. Requires the input "suffix-table".
class SuffixFeature : public AffixTableFeature {
 public:
  SuffixFeature() : AffixTableFeature(AffixTable::SUFFIX) {}
};

// Offset locator: moves the focus by the locator's argument and leaves
// everything else untouched.
class Offset : public Locator<Offset> {
 public:
  // Adds argument() to |*focus|. The workspaces and sentence are not used.
  void UpdateArgs(const WorkspaceSet &workspaces,
                  const Sentence &sentence, int *focus) const {
    const auto shift = argument();
    *focus += shift;
  }
};

// Feature extractor over (Sentence, int) focus pairs.
using SentenceExtractor = FeatureExtractor<Sentence, int>;

// Utility to register the sentence_instance::Feature functions: registers
// |type| as a SentenceFeature under |name|.
#define REGISTER_SENTENCE_IDX_FEATURE(name, type) \
  REGISTER_SYNTAXNET_FEATURE_FUNCTION(SentenceFeature, name, type)
641
642
643

}  // namespace syntaxnet

calberti's avatar
calberti committed
644
#endif  // SYNTAXNET_SENTENCE_FEATURES_H_