sequence_model.h 5.3 KB
Newer Older
Terry Koo's avatar
Terry Koo committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
// Copyright 2017 Google Inc. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================

#ifndef DRAGNN_RUNTIME_SEQUENCE_MODEL_H_
#define DRAGNN_RUNTIME_SEQUENCE_MODEL_H_

#include <stddef.h>
#include <memory>
#include <string>

#include "dragnn/core/compute_session.h"
#include "dragnn/core/input_batch_cache.h"
#include "dragnn/protos/spec.pb.h"
#include "dragnn/runtime/fixed_embeddings.h"
#include "dragnn/runtime/linked_embeddings.h"
#include "dragnn/runtime/network_states.h"
#include "dragnn/runtime/sequence_features.h"
#include "dragnn/runtime/sequence_links.h"
#include "dragnn/runtime/sequence_predictor.h"
#include "dragnn/runtime/session_state.h"
#include "syntaxnet/base.h"
#include "tensorflow/core/lib/core/status.h"

namespace syntaxnet {
namespace dragnn {
namespace runtime {

// A class that configures and helps evaluate a sequence-based model.
//
// This class requires the SequenceBackend component backend and elides most of
// the ComputeSession feature extraction and transition system overhead.
class SequenceModel {
 public:
  // State associated with a single evaluation of the model.
  struct EvaluateState {
    // Number of transition steps in the current sequence.
    size_t num_steps = 0;

    // Current input batch.
    InputBatchCache *input = nullptr;

    // Sequence-based fixed features.
    SequenceFeatures features;

    // Sequence-based linked embeddings.
    SequenceLinks links;
  };

  // Creates an uninitialized model.  Call Initialize() before use.
  SequenceModel() = default;

  // Returns true if the |component_spec| is compatible with a sequence model.
  static bool Supports(const ComponentSpec &component_spec);

  // Initalizes this from the configuration in the |component_spec|.  Wraps the
  // |fixed_embedding_manager| and |linked_embedding_manager| in sequence-based
  // versions, and requests layers from the |network_state_manager|.  All of the
  // managers must outlive this.  If the transition system is non-deterministic,
  // uses the layer named |logits_name| to make predictions later in Predict();
  // otherwise, |logits_name| is ignored and Predict() does nothing.  On error,
  // returns non-OK.
  tensorflow::Status Initialize(
      const ComponentSpec &component_spec, const string &logits_name,
      const FixedEmbeddingManager *fixed_embedding_manager,
      const LinkedEmbeddingManager *linked_embedding_manager,
      NetworkStateManager *network_state_manager);

  // Resets the |evaluate_state| to values derived from the |session_state| and
  // |compute_session|.  Also updates the NetworkStates in the |session_state|
  // and the current component of the |compute_session| with the length of the
  // current sequence.  Call this before producing output layers.  On error,
  // returns non-OK.
  tensorflow::Status Preprocess(SessionState *session_state,
                                ComputeSession *compute_session,
                                EvaluateState *evaluate_state) const;

  // If applicable, makes predictions based on the logits in |network_states|
  // and applies them to the input in the |evaluate_state|.  Call this after
  // producing output layers.  On error, returns non-OK.
  tensorflow::Status Predict(const NetworkStates &network_states,
                             EvaluateState *evaluate_state) const;

  // Accessors.
  bool deterministic() const { return deterministic_; }
  bool left_to_right() const { return left_to_right_; }
  const SequenceLinkManager &sequence_link_manager() const;
  const SequenceFeatureManager &sequence_feature_manager() const;

 private:
  // Name of the component that this model is a part of.
  string component_name_;

  // Whether the underlying transition system is deterministic.
  bool deterministic_ = false;

  // Whether to process sequences from left to right.
  bool left_to_right_ = true;

  // Whether fixed or linked features are present.
  bool have_fixed_features_ = false;
  bool have_linked_features_ = false;

  // Handle to the logits layer.  Only used if |deterministic_| is false.
  LayerHandle<float> logits_handle_;

  // Manager for sequence-based feature extractors.
  SequenceFeatureManager sequence_feature_manager_;

  // Manager for sequence-based linked embeddings.
  SequenceLinkManager sequence_link_manager_;

  // Sequence-based predictor, if |deterministic_| is false.
  std::unique_ptr<SequencePredictor> sequence_predictor_;
};

// Implementation details below.

inline const SequenceLinkManager &SequenceModel::sequence_link_manager() const {
  return sequence_link_manager_;
}

inline const SequenceFeatureManager &SequenceModel::sequence_feature_manager()
    const {
  return sequence_feature_manager_;
}

}  // namespace runtime
}  // namespace dragnn
}  // namespace syntaxnet

#endif  // DRAGNN_RUNTIME_SEQUENCE_MODEL_H_