# Creating a DRAGNN Component

[TOC]

## Why Create A Component?

A Component is the logic that performs actions based on TensorFlow inference
results. DRAGNN allows downstream components to access data from Components
that have already run, both during training and during inference. Wrapping your
inference-using logic in a Component therefore lets it be trained and used with
inferred data as input, and lets downstream components use its data as input
during training and inference. This doesn't even need to be a linguistic task -
any neural-net computation could be wrapped as a Component.

In addition, Components can be created to very efficiently perform a small set
of actions - for instance, computation with a Beam can be expensive, so if your
application does not need a Beam, you could create a component without one and
save its computational overhead.

## The Component Interface

All Components must implement the Component interface (located in
dragnn/core/interfaces/component.h). Of special note are the translator access
functions, which must be implemented for downstream components to correctly
examine the component's data.

These functions are:

```
int StepsTaken(int batch_index)
```

This function should return the number of steps taken by the Component as it has
operated on batch index "batch_index". Since each Component can operate on many
data items in parallel, some of which may become final earlier than others, it
is necessary to specify the batch index here. Also note that batch indices are
assumed to be constant throughout the DRAGNN system and between components -
batch i is the computation that corresponds to element i of the vector of input
data, always.
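
As a minimal sketch of the bookkeeping this implies, the helper below keeps one
step counter per batch element. `StepCounter` and `Advance` are hypothetical
names introduced for illustration and are not part of DRAGNN; only the
`StepsTaken(int batch_index)` shape comes from the interface above.

```cpp
#include <cassert>
#include <vector>

// Hypothetical helper (not a DRAGNN class): tracks how many steps each batch
// element has taken so a component can answer StepsTaken(batch_index).
class StepCounter {
 public:
  explicit StepCounter(int batch_size) : steps_(batch_size, 0) {}

  // Records one step for a batch element that has not yet become final.
  void Advance(int batch_index) {
    assert(batch_index >= 0 && batch_index < static_cast<int>(steps_.size()));
    ++steps_[batch_index];
  }

  // Same shape as the StepsTaken function described above.
  int StepsTaken(int batch_index) const {
    assert(batch_index >= 0 && batch_index < static_cast<int>(steps_.size()));
    return steps_[batch_index];
  }

 private:
  std::vector<int> steps_;  // steps_[i] belongs to input element i, always.
};
```

Because batch elements can finish at different times, a component would call
`Advance` only for elements that are still being processed, so the counts may
differ across the batch.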

```
int GetBeamIndexAtStep(int step, int current_index, int batch)
```

This function should look up the TransitionState that is currently at element
'current_index' in the state's beam for batch element 'batch', then determine
where that element was in that batch element's beam at step 'step'. If the
lookup is out of bounds, return -1.

```
int GetSourceBeamIndex(int current_index, int batch)
```

This function should return the 'source beam index' of the TransitionState
currently at element 'current_index' for batch element 'batch'. To find the
source beam index, first determine what beam index the element was in at time 0
(at initialization). Then, determine which element *from the previous
component* was used to initialize that element; that element's index is the
source beam index. (This is used to maintain beam history throughout DRAGNN.)
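
The two history lookups above can be sketched for a single batch element by
storing, for every step, the slot each state came from. `BeamHistory`, `Init`
and `RecordStep` are hypothetical names, and the step numbering (step 0 is the
beam at initialization) is an assumption made for this sketch rather than
something DRAGNN prescribes.

```cpp
#include <cassert>
#include <utility>
#include <vector>

// Hypothetical per-batch-element history (not a DRAGNN class).
class BeamHistory {
 public:
  // source_indices[i] is the slot, in the previous component's beam, of the
  // state used to initialize slot i of this beam.
  void Init(std::vector<int> source_indices) {
    parents_.clear();
    parents_.push_back(std::move(source_indices));
  }

  // parent_indices[i] is the slot that the state now at slot i occupied
  // before this step.
  void RecordStep(std::vector<int> parent_indices) {
    assert(!parents_.empty());  // Init must be called first.
    parents_.push_back(std::move(parent_indices));
  }

  // Walks backwards from the current beam to the beam at `step`.
  int GetBeamIndexAtStep(int step, int current_index) const {
    const int last = static_cast<int>(parents_.size()) - 1;
    if (step < 0 || step > last) return -1;
    int index = current_index;
    for (int t = last; t > step; --t) {
      if (index < 0 || index >= static_cast<int>(parents_[t].size())) return -1;
      index = parents_[t][index];
    }
    if (index < 0 || index >= static_cast<int>(parents_[step].size())) return -1;
    return index;
  }

  // Finds the slot at initialization, then reads its recorded source slot.
  int GetSourceBeamIndex(int current_index) const {
    const int initial = GetBeamIndexAtStep(0, current_index);
    return initial < 0 ? -1 : parents_[0][initial];
  }

 private:
  std::vector<std::vector<int>> parents_;  // parents_[step][slot]
};
```

A real component would keep one such history per batch element and pass the
'batch' argument through to select it.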

### When Is A Component Terminal?

One of the important concepts for a Component is the idea of "being terminal".
A Component is terminal when all of its batch elements are completely analyzed;
that is, no computation is left to perform for any element. For a parser, for
instance, this occurs when all tokens have been examined.

Components must always become terminal after some number of steps; if not,
DRAGNN will become caught in an infinite loop. The number of steps does not have
to be deterministic, however.
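
The termination rule can be illustrated with a toy component whose batch
elements each have a fixed amount of work left. `CountdownComponent`,
`AdvanceAll`, `RunToCompletion` and the `max_steps` guard are all hypothetical;
this is a sketch of the idea, not how DRAGNN drives components.

```cpp
#include <algorithm>
#include <cassert>
#include <utility>
#include <vector>

// Hypothetical toy component: each batch element has a number of tokens left.
class CountdownComponent {
 public:
  explicit CountdownComponent(std::vector<int> tokens_per_element)
      : remaining_(std::move(tokens_per_element)) {}

  // Terminal only when no batch element has work left.
  bool IsTerminal() const {
    return std::all_of(remaining_.begin(), remaining_.end(),
                       [](int left) { return left <= 0; });
  }

  // Advances every batch element that is not yet final by one step.
  void AdvanceAll() {
    for (int &left : remaining_) {
      if (left > 0) --left;
    }
  }

 private:
  std::vector<int> remaining_;
};

// Steps the component until it is terminal and returns the step count. The
// assert is a debugging guard against components that never terminate.
int RunToCompletion(CountdownComponent *component, int max_steps) {
  int steps = 0;
  while (!component->IsTerminal()) {
    assert(steps < max_steps && "component never became terminal");
    component->AdvanceAll();
    ++steps;
  }
  return steps;
}
```

Elements that finish early simply stop advancing, and the component as a whole
becomes terminal when the slowest element is done.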

### Defining Component Input

A DRAGNN graph is fed with strings, and each Component determines how to
interpret them. For example, the SyntaxNetComponent expects each string to be a
serialized Sentence protocol buffer. (If you want to operate on Sentence
objects, it probably makes more sense to add `syntaxnet::TransitionSystem`
classes than to write a new Component from scratch.)

If you want to read other types of data, you will need to sub-class
InputBatchCache, a container which holds the strings and deserializes them into
typed data on demand. Usage is fairly straightforward; see the SyntaxNet
implementation for how to use it.

**Note: all Components in the graph should use a single InputBatchCache
sub-class. If you need multiple data types, you'll want to have a single
InputBatchCache that has fields for each of your data types.**

## Using Transition States

Each Component is expected to initialize itself from a beam of TransitionStates
(which may be empty, if the component is the first one in a DRAGNN computation)
and is expected to emit a set of TransitionStates when it is complete (via a
call to GetBeam). There is no requirement that TransitionStates be used
internally, but it does make things easier (for one thing, if you use
CloneableTransitionStates, you can use the provided Beam class to track your
component history and beam state).

### Initializing From TransitionStates

When your component is initialized, it will receive a vector of vectors of
TransitionStates. The outer vector is indexed by batch element and each inner
vector is in beam state order, so you should use the inner vector at index i to
initialize the beam for batch element i. When initializing new
TransitionStates from an inner vector's states, be sure to record the index
of each source element in the input beam somewhere; you'll need to be able to
return it from the GetSourceBeamIndex function. If you don't, translations
won't work properly.
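
A minimal sketch of that initialization, assuming hypothetical `ParentState`
and `InitState` structs in place of real TransitionState classes, might look as
follows. The key point is that each new state stores the slot of the parent
state it was created from.

```cpp
#include <cassert>
#include <cstddef>
#include <memory>
#include <utility>
#include <vector>

// Hypothetical stand-in for a state emitted by the previous component.
struct ParentState {
  double score;
};

// Hypothetical state for this component.
struct InitState {
  int source_beam_index;  // Slot of the parent state in the input beam.
  double score;
};

// Builds one beam per batch element. The outer index is the batch element and
// the inner index is the slot in that element's input beam. An empty input
// beam yields an empty beam here; a first-in-pipeline component would create
// its starting state some other way.
std::vector<std::vector<std::unique_ptr<InitState>>> InitializeBeams(
    const std::vector<std::vector<const ParentState *>> &parent_states) {
  std::vector<std::vector<std::unique_ptr<InitState>>> beams(
      parent_states.size());
  for (std::size_t batch = 0; batch < parent_states.size(); ++batch) {
    const auto &parent_beam = parent_states[batch];
    for (std::size_t slot = 0; slot < parent_beam.size(); ++slot) {
      assert(parent_beam[slot] != nullptr);
      auto state = std::make_unique<InitState>();
      state->source_beam_index = static_cast<int>(slot);  // Needed later.
      state->score = parent_beam[slot]->score;
      beams[batch].push_back(std::move(state));
    }
  }
  return beams;
}
```

With the index stored on the state, GetSourceBeamIndex reduces to finding the
state's slot at initialization and reading this field back.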

### Finalizing & Emitting TransitionState Data

When your component is terminal, the ComputeSession will call FinalizeData() and
GetBeam() in that order. These steps "lock" the component computation and return
pointers to the final state of the component as captured in its
TransitionStates. Note that these are raw pointers - components always retain
ownership of their own transition states!

When data is finalized, the best (in our case, the highest scoring, but you may
have a different metric) result is written to the underlying data that was
passed to the component. Writing the result back to the data will ensure that
components later in the pipeline can use that result in their computations.

Emitting data is straightforward: create a vector of vectors of pointers to the
beam states in your component, and return it.
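
As a sketch of that emission step, assume the component owns its states through
`std::unique_ptr` and hands out non-owning raw pointers. `FinalState`,
`OwnedBeams` and `EmitBeams` are hypothetical names used only for this example;
the real GetBeam works on TransitionState pointers.

```cpp
#include <cassert>
#include <memory>
#include <utility>
#include <vector>

// Hypothetical state type standing in for a TransitionState subclass.
struct FinalState {
  double score;
};

// The component owns its states; the outer index is the batch element.
using OwnedBeams = std::vector<std::vector<std::unique_ptr<FinalState>>>;

// Returns non-owning pointers in the same batch and beam order. The pointers
// stay valid only while the owning component keeps its states alive.
std::vector<std::vector<const FinalState *>> EmitBeams(
    const OwnedBeams &beams) {
  std::vector<std::vector<const FinalState *>> output;
  output.reserve(beams.size());
  for (const auto &beam : beams) {
    std::vector<const FinalState *> pointers;
    pointers.reserve(beam.size());
    for (const auto &state : beam) {
      assert(state != nullptr);
      pointers.push_back(state.get());
    }
    output.push_back(std::move(pointers));
  }
  return output;
}
```

Keeping ownership inside the component means downstream code must not delete
these pointers or hold them past the component's lifetime.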

### Basic Testing

If you would like to validate that your TransitionState meets the DRAGNN
contract expectations, you can use the transition_state_starter_test.cc, and
adapt it to use your transition state. If all tests pass, your TransitionState
should work with the rest of DRAGNN.

## Creating Translators

One of the key features of DRAGNN is "translators" - functions that allow
components executing later in a DRAGNN pipeline to access data from earlier
components. There are two types of translators - universal and backend-specific.

### Maintaining History For Universal Translators

In order to support translation, your component must be able to report history
via the translator access functions of the Component interface (StepsTaken,
GetBeamIndexAtStep and GetSourceBeamIndex). This is required whether or not you
want to directly support translation in your component - in order for
translation to work at all in the DRAGNN pipeline, all components must
implement these methods.

### Defining Backend Specific Translators

If you would like downstream components to be able to access your component's
data in a more complex manner than the universal translators allow, you can
define your own translation functions. This is done via the
`GetStepLookupFunction` method in the component, which returns an arbitrary
function when given a string name. To create a translator with a specific name,
make the `GetStepLookupFunction` method return it when queried; once done, that
name can be specified as a LinkedFeature for components downstream.

Note that, when defining a backend-specific translator, the arguments to the
returned std::function are (int batch_index, int beam_index, int value), or,
"for the transition state at the given beam_index of the beam corresponding to
batch batch_index, return a function mapping value to a location in the input
tensor for this step."
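
A minimal sketch of such a named translator is shown below. The
`LookupComponent` class, the `StepLookup` alias, the translator name
`"reversed-position"` and its mapping are all hypothetical; only the
`GetStepLookupFunction` method name and the (batch_index, beam_index, value)
argument order come from the description above. Returning an empty function for
an unknown name is also just a choice made for this sketch.

```cpp
#include <cassert>
#include <functional>
#include <string>
#include <utility>
#include <vector>

// Arguments are (batch_index, beam_index, value), as described above.
using StepLookup = std::function<int(int, int, int)>;

// Hypothetical component that knows how many tokens each batch element has.
class LookupComponent {
 public:
  explicit LookupComponent(std::vector<int> token_counts)
      : token_counts_(std::move(token_counts)) {}

  // Returns the translator registered under `name`, or an empty function.
  StepLookup GetStepLookupFunction(const std::string &name) const {
    if (name == "reversed-position") {
      // Copy the counts so the returned function does not depend on `this`.
      const std::vector<int> counts = token_counts_;
      return [counts](int batch_index, int beam_index, int value) {
        (void)beam_index;  // This toy mapping is the same for every beam slot.
        assert(batch_index >= 0 &&
               batch_index < static_cast<int>(counts.size()));
        const int size = counts[batch_index];
        if (value < 0 || value >= size) return -1;
        return size - 1 - value;  // Position counted from the end.
      };
    }
    return StepLookup();
  }

 private:
  std::vector<int> token_counts_;
};
```

A downstream component would then refer to the translator by the same string
name when declaring its linked feature.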

## Beams & The Beam Class

If you would like a helper class to keep track of beams and history for your
component, you can use core/beam.h. Make sure your TransitionState implements
the CloneableTransitionState interface, instantiate a beam (templated to your
TransitionState type) for each batch element, and you're good to go - all you
need to do is plumb the relevant methods from the Component interface to the
Beam. (For an example, you can examine syntaxnet_component.cc, which uses a Beam
of SyntaxNetTransitionState objects.)