bulk_component.py 22.1 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
# Copyright 2017 Google Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
Ivan Bogatyy's avatar
Ivan Bogatyy committed
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
"""Component builders for non-recurrent networks in DRAGNN."""


import tensorflow as tf
from tensorflow.python.platform import tf_logging as logging

from dragnn.python import component
from dragnn.python import dragnn_ops
from dragnn.python import network_units
from syntaxnet.util import check


def fetch_linked_embedding(comp, network_states, feature_spec):
  """Fetches another component's bulk activations as a linked feature.

  Only untranslated ('identity') links to a *different* component are
  supported: in bulk mode linked features are read directly from the source
  layer, without passing through any transition system or source translator.

  Args:
    comp: ComponentBuilder object that consumes the linked feature.
    network_states: Dictionary mapping component names to NetworkState
        objects.
    feature_spec: FeatureSpec proto describing the linked feature.

  Returns:
    NamedTensor wrapping the source layer's bulk tensor, named after
    |feature_spec|.

  Raises:
    NotImplementedError: if the source translator is not 'identity'.
    RuntimeError: if the feature links |comp| to itself (a recurrent link).
  """
  translator = feature_spec.source_translator
  if translator != 'identity':
    raise NotImplementedError(translator)

  source_name = feature_spec.source_component
  if source_name == comp.name:
    raise RuntimeError(
        'Recurrent linked features are not supported in bulk extraction.')

  tf.logging.info('[%s] Adding linked feature "%s"', comp.name,
                  feature_spec.name)
  source_component = comp.master.lookup_component[source_name]
  source_state = network_states[source_component.name]
  layer_activations = source_state.activations[feature_spec.source_layer]
  return network_units.NamedTensor(layer_activations.bulk_tensor,
                                   feature_spec.name)


def _validate_embedded_fixed_features(comp):
  """Verifies the embedding configuration of every fixed feature of |comp|.

  Each fixed feature must request a positive embedding dimension, and any
  feature marked constant must also carry a pretrained embedding matrix.
  Violations are reported through the |check| helpers (presumably by raising;
  confirm in syntaxnet.util.check).

  Args:
    comp: Component whose |spec.fixed_feature| entries are checked.
  """
  for fixed_spec in comp.spec.fixed_feature:
    check.Gt(fixed_spec.embedding_dim, 0,
             'Embeddings requested for non-embedded feature: %s' % fixed_spec)
    if not fixed_spec.is_constant:
      continue
    check.IsTrue(
        fixed_spec.HasField('pretrained_embedding_matrix'),
        'Constant embeddings must be pretrained: %s' % fixed_spec)


def fetch_differentiable_fixed_embeddings(comp, state, stride, during_training):
  """Embeds fixed features via separate, differentiable per-channel lookups.

  Raw feature indices, IDs and weights are extracted in bulk. Each channel is
  then embedded with |network_units.embedding_lookup| against that channel's
  embedding variable. Constant features are wrapped in tf.stop_gradient.

  Args:
    comp: Component whose fixed features we wish to look up.
    state: Live MasterState object for the component.
    stride: Tensor containing current batch * beam size.
    during_training: True if this is being called from a training code path.
      This controls, e.g., the use of feature ID dropout.

  Returns:
    state handle: Updated state handle to be used after this call.
    fixed_embeddings: List of NamedTensor objects, one per fixed feature
        (empty if |comp| has no fixed features).
  """
  _validate_embedded_fixed_features(comp)
  specs = comp.spec.fixed_feature
  channel_count = len(specs)
  if channel_count == 0:
    return state.handle, []

  extracted = dragnn_ops.bulk_fixed_features(
      state.handle, component=comp.name, num_channels=channel_count)
  state.handle, indices, ids, weights, num_steps = extracted

  embeddings = []
  for channel, feature_spec in enumerate(specs):
    kind = 'constant' if feature_spec.is_constant else 'differentiable'
    tf.logging.info('[%s] Adding %s fixed feature "%s"', comp.name, kind,
                    feature_spec.name)

    # Feature ID dropout applies only on training paths, and only to features
    # that configure a non-negative dropout ID.
    channel_ids = ids[channel]
    channel_weights = weights[channel]
    if during_training and feature_spec.dropout_id >= 0:
      channel_ids, channel_weights = network_units.apply_feature_id_dropout(
          channel_ids, channel_weights, feature_spec)

    lookup_size = stride * num_steps * feature_spec.size
    embedding = network_units.embedding_lookup(
        comp.get_variable(network_units.fixed_embeddings_name(channel)),
        indices[channel], channel_ids, channel_weights, lookup_size)
    if feature_spec.is_constant:
      # Constant (pretrained) embeddings are excluded from gradient updates.
      embedding = tf.stop_gradient(embedding)
    embeddings.append(network_units.NamedTensor(embedding, feature_spec.name))

  return state.handle, embeddings


def fetch_fast_fixed_embeddings(comp,
                                state,
                                pad_to_batch=None,
                                pad_to_steps=None):
  """Embeds all fixed features at once with a fast, non-differentiable op.

  The bulk embedding ops are not differentiable with respect to the embedding
  matrices, so this is meant for graphs that are not being trained. When both
  |pad_to_batch| and |pad_to_steps| are given, the padded
  'bulk_embed_fixed_features' op is used, which is the most efficient
  extractor. Otherwise 'bulk_fixed_embeddings' is used.

  Args:
    comp: Component whose fixed features we wish to look up.
    state: Live MasterState object for the component.
    pad_to_batch: Optional; the number of batch elements to pad to.
    pad_to_steps: Optional; the number of steps to pad to.

  Returns:
    state handle: Updated state handle to be used after this call.
    fixed_embeddings: [] if |comp| has no fixed features; otherwise a
        one-element list holding a NamedTensor named
        'bulk-<component name>-fixed-features'.
  """
  _validate_embedded_fixed_features(comp)
  channel_count = len(comp.spec.fixed_feature)
  if not channel_count:
    return state.handle, []
  tf.logging.info('[%s] Adding %d fast fixed features', comp.name,
                  channel_count)

  embedding_matrices = [
      comp.get_variable(network_units.fixed_embeddings_name(channel))
      for channel in range(channel_count)
  ]

  use_padding = pad_to_batch is not None and pad_to_steps is not None
  if use_padding:
    # Fixed padding sizes allow the fastest embedding extractor.
    extraction = dragnn_ops.bulk_embed_fixed_features(
        state.handle,
        embedding_matrices,
        component=comp.name,
        pad_to_batch=pad_to_batch,
        pad_to_steps=pad_to_steps)
  else:
    extraction = dragnn_ops.bulk_fixed_embeddings(
        state.handle, embedding_matrices, component=comp.name)
  state.handle, embedded, _ = extraction

  named_embeddings = network_units.NamedTensor(
      embedded, 'bulk-%s-fixed-features' % comp.name)
  return state.handle, [named_embeddings]


def fetch_dense_ragged_embeddings(comp, state):
  """Embeds all fixed features of |comp| in a dense, ragged layout.

  Calls the 'bulk_embed_dense_fixed_features' op with every fixed embedding
  matrix of |comp> and wraps its 'data' and 'offsets' outputs as NamedTensors.

  Args:
    comp: Component whose fixed features we wish to look up.
    state: Live MasterState object for the component.

  Returns:
    state handle: Updated state handle to be used after this call.
    fixed_embeddings: [] if |comp| has no fixed features; otherwise the list
        [data, offsets] of NamedTensors named 'dense-<component name>-data'
        and 'dense-<component name>-offsets'. (Presumably |offsets| delimits
        each batch element's rows within |data|; confirm against the op.)
  """
  _validate_embedded_fixed_features(comp)
  channel_count = len(comp.spec.fixed_feature)
  if channel_count == 0:
    return state.handle, []
  tf.logging.info('[%s] Adding %d fast fixed features', comp.name,
                  channel_count)

  embedding_matrices = [
      comp.get_variable(network_units.fixed_embeddings_name(channel))
      for channel in range(channel_count)
  ]
  state.handle, dense_data, dense_offsets = (
      dragnn_ops.bulk_embed_dense_fixed_features(
          state.handle, embedding_matrices, component=comp.name))

  return state.handle, [
      network_units.NamedTensor(dense_data, 'dense-%s-data' % comp.name),
      network_units.NamedTensor(dense_offsets, 'dense-%s-offsets' % comp.name),
  ]


def extract_fixed_feature_ids(comp, state, stride):
  """Extracts raw fixed-feature IDs for every step of the batch.

  Every fixed feature must have size 1 and must be non-embedded.

  Args:
    comp: Component whose fixed feature IDs we wish to extract.
    state: Live MasterState object for the component.
    stride: Tensor containing current batch * beam size.

  Returns:
    state handle: Updated state handle to be used after this call.
    ids: List of NamedTensors (dim=1), one per channel, each holding
        [stride * num_steps, 1] feature IDs. Positions that received no ID
        (e.g., due to batch padding) are set to -1.
  """
  specs = comp.spec.fixed_feature
  channel_count = len(specs)
  if not channel_count:
    return state.handle, []

  for spec in specs:
    check.Eq(spec.size, 1, 'All features must have size=1')
    check.Lt(spec.embedding_dim, 0, 'All features must be non-embedded')

  state.handle, indices, ids, _, num_steps = dragnn_ops.bulk_fixed_features(
      state.handle, component=comp.name, num_channels=channel_count)
  segment_count = stride * num_steps

  results = []
  for channel, spec in enumerate(specs):
    tf.logging.info('[%s] Adding fixed feature IDs "%s"', comp.name,
                    spec.name)

    # Shifting IDs up by one before the segment sum and back down afterwards
    # makes segments that received no ID come out as -1.
    #
    # TODO(googleuser): This formula breaks if multiple IDs are extracted at some
    # step.  Try using tf.unique() to enforce the unique-IDS precondition.
    shifted_sums = tf.unsorted_segment_sum(ids[channel] + 1, indices[channel],
                                           segment_count)
    channel_ids = tf.expand_dims(shifted_sums - 1, axis=1)
    results.append(network_units.NamedTensor(channel_ids, spec.name, dim=1))
  return state.handle, results


def update_network_states(comp, tensors, network_states, stride):
  """Records |comp|'s layer output tensors in its NetworkState.

  The i-th tensor is stored under the name of the i-th layer of
  |comp.network|, so that later components can read it (for example, as a
  linked feature).

  Args:
    comp: Component for which the tensor handles are being stored.
    tensors: List of Tensors to store, aligned with |comp.network.layers|.
    network_states: Dictionary of component NetworkState objects.
    stride: Stride of the stored tensors.
  """
  state_for_comp = network_states[comp.name]
  with tf.name_scope(comp.name + '/stored_act'):
    for position, tensor in enumerate(tensors):
      layer = comp.network.layers[position]
      state_for_comp.activations[layer.name] = network_units.StoredActivations(
          tensor=tensor, stride=stride, dim=layer.dim)


def build_cross_entropy_loss(logits, gold):
  """Builds a mean softmax cross-entropy loss over rows with valid gold labels.

  Rows whose gold label is not greater than -1 (i.e. the magic -1 value) are
  dropped before the loss and accuracy are computed.

  Args:
    logits: float Tensor of scores, one row per example.
    gold: int Tensor of gold label ids (sparse ids, not one-hot); -1 marks rows
        to skip.

  Returns:
    cost: Summed cross entropy over the kept rows divided by their count
        (i.e. the mean over kept rows). It is guarded by a tf.assert_positive
        check on the number of kept rows.
    correct: Number of kept rows whose top-1 prediction equals the gold label.
    total: Number of kept rows.
  """
  keep_rows = tf.reshape(tf.where(tf.greater(gold, -1)), [-1])
  kept_gold = tf.gather(gold, keep_rows)
  kept_logits = tf.gather(logits, keep_rows)
  correct = tf.reduce_sum(
      tf.to_int32(tf.nn.in_top_k(kept_logits, kept_gold, 1)))
  total = tf.size(kept_gold)

  with tf.control_dependencies([tf.assert_positive(total)]):
    summed_loss = tf.reduce_sum(
        tf.nn.sparse_softmax_cross_entropy_with_logits(
            labels=tf.cast(kept_gold, tf.int64), logits=kept_logits))
    cost = summed_loss / tf.cast(total, tf.float32)
  return cost, correct, total


class BulkFeatureExtractorComponentBuilder(component.ComponentBuilderBase):
  """Builds a component that extracts all of its features in bulk.

  Fixed and linked features are both supported, subject to these limits:
  1. Fixed features cannot be recurrent. They are extracted along the gold
     path, which does not work during inference.
  2. Linked features cannot be recurrent and must be untranslated. For now
     they are read directly from the source component, without passing
     through any transition system or source translator.
  """

  def build_greedy_training(self, state, network_states):
    """Extracts features and runs the network along the oracle path.

    Args:
      state: MasterState from the 'AdvanceMaster' op that advances the
          underlying master to this component.
      network_states: Dictionary of component NetworkState objects.

    Returns:
      state handle: Final state after advancing.
      cost: Result of self.add_regularizer applied to a zero base cost.
      correct: Always 0, since no gold path is available.
      total: Always 0, since no gold path is available.
    """
    logging.info('Building component: %s', self.spec.name)
    stride = state.current_batch_size * self.training_beam_size
    self.network.pre_create(stride)

    with tf.variable_scope(self.name, reuse=True):
      state.handle, fixed_embeddings = fetch_differentiable_fixed_embeddings(
          self, state, stride, True)

    linked_embeddings = [
        fetch_linked_embedding(self, network_states, linked_spec)
        for linked_spec in self.spec.linked_feature
    ]

    with tf.variable_scope(self.name, reuse=True):
      tensors = self.network.create(
          fixed_embeddings, linked_embeddings, None, None, True, stride=stride)
    update_network_states(self, tensors, network_states, stride)

    regularized_cost = self.add_regularizer(tf.constant(0.))
    correct, total = tf.constant(0), tf.constant(0)
    return state.handle, regularized_cost, correct, total

  def build_post_restore_hook(self):
    """Builds ops that should run once, after the restore op.

    These ops are intended to run a single time, before the inference
    pipeline is run.

    Returns:
      A list holding the op returned by the network's
      build_post_restore_hook() if the network defines a callable one;
      otherwise an empty list.
    """
    logging.info('Building restore hook for component: %s', self.spec.name)
    with tf.variable_scope(self.name):
      hook_builder = getattr(self.network, 'build_post_restore_hook', None)
      if not callable(hook_builder):
        return []
      return [hook_builder()]

  def build_greedy_inference(self, state, network_states,
                             during_training=False):
    """Extracts features and runs the network over the whole batch.

    The original design note (danielandor) says this method is not meant to
    be called during training, i.e. unroll_using_oracle should be true for
    this component. A |during_training| path is still provided, and it uses
    differentiable fixed-embedding lookups.

    Fixed embeddings are fetched by the first matching rule:
      * during_training: fetch_differentiable_fixed_embeddings.
      * 'use_densors' network-unit parameter: fetch_dense_ragged_embeddings.
      * both 'padded_batch_size' and 'padded_sentence_length' parameters:
        fetch_fast_fixed_embeddings, padded to the sentence length.
      * otherwise: unpadded fetch_fast_fixed_embeddings.

    Args:
      state: MasterState from the 'AdvanceMaster' op that advances the
          underlying master to this component.
      network_states: Dictionary of component NetworkState objects.
      during_training: Whether the graph is being constructed during training.

    Returns:
      state handle: Final state after advancing.
    """
    logging.info('Building component: %s', self.spec.name)
    if during_training:
      beam_size = self.training_beam_size
    else:
      beam_size = self.inference_beam_size
    stride = state.current_batch_size * beam_size
    self.network.pre_create(stride)

    unit_params = self.spec.network_unit.parameters
    with tf.variable_scope(self.name, reuse=True):
      if during_training:
        state.handle, fixed_embeddings = fetch_differentiable_fixed_embeddings(
            self, state, stride, during_training)
      elif 'use_densors' in unit_params:
        state.handle, fixed_embeddings = fetch_dense_ragged_embeddings(
            self, state)
      elif ('padded_batch_size' in unit_params and
            'padded_sentence_length' in unit_params):
        # NOTE(review): only the sentence length is forwarded. The batch is
        # passed as -1 rather than the value of 'padded_batch_size'. This is
        # preserved as-is; confirm the op's meaning of pad_to_batch=-1.
        state.handle, fixed_embeddings = fetch_fast_fixed_embeddings(
            self,
            state,
            pad_to_batch=-1,
            pad_to_steps=int(unit_params['padded_sentence_length']))
      else:
        state.handle, fixed_embeddings = fetch_fast_fixed_embeddings(
            self, state)

    linked_embeddings = [
        fetch_linked_embedding(self, network_states, linked_spec)
        for linked_spec in self.spec.linked_feature
    ]

    with tf.variable_scope(self.name, reuse=True):
      tensors = self.network.create(
          fixed_embeddings,
          linked_embeddings,
          None,
          None,
          during_training=during_training,
          stride=stride)

    update_network_states(self, tensors, network_states, stride)
    self._add_runtime_hooks()
    return state.handle


class BulkFeatureIdExtractorComponentBuilder(component.ComponentBuilderBase):
  """Builds a component that extracts raw fixed-feature IDs in bulk.

  This is a variant of BulkFeatureExtractorComponentBuilder restricted to
  fixed, non-embedded features. Instead of embedding them, it feeds the
  integer feature IDs directly to the network. Because those IDs are
  integers, the results are generally not differentiable. Linked features are
  rejected.
  """

  def __init__(self, master, component_spec):
    """Initializes the builder and validates its feature configuration.

    Args:
      master: dragnn.MasterBuilder object.
      component_spec: dragnn.ComponentSpec proto to be built.
    """
    super(BulkFeatureIdExtractorComponentBuilder, self).__init__(
        master, component_spec)
    check.Eq(len(self.spec.linked_feature), 0, 'Linked features are forbidden')
    for fixed_spec in self.spec.fixed_feature:
      check.Lt(fixed_spec.embedding_dim, 0,
               'Features must be non-embedded: %s' % fixed_spec)

  def build_greedy_training(self, state, network_states):
    """Extracts feature IDs on a training path; see base class.

    Returns:
      The state handle, the regularizer applied to a zero base cost, and
      zero correct/total counts, since no gold labels are scored.
    """
    state.handle = self._extract_feature_ids(state, network_states, True)
    regularized_cost = self.add_regularizer(tf.constant(0.))
    correct, total = tf.constant(0), tf.constant(0)
    return state.handle, regularized_cost, correct, total

  def build_greedy_inference(self, state, network_states,
                             during_training=False):
    """Extracts feature IDs for inference; see base class."""
    final_handle = self._extract_feature_ids(state, network_states,
                                             during_training)
    self._add_runtime_hooks()
    return final_handle

  def _extract_feature_ids(self, state, network_states, during_training):
    """Extracts feature IDs, runs the network on them, and stores outputs.

    Args:
      state: MasterState from the 'AdvanceMaster' op that advances the
          underlying master to this component.
      network_states: Dictionary of component NetworkState objects.
      during_training: Whether the graph is being constructed during training.

    Returns:
      state handle: Final state after advancing.
    """
    logging.info('Building component: %s', self.spec.name)

    if during_training:
      beam_size = self.training_beam_size
    else:
      beam_size = self.inference_beam_size
    stride = state.current_batch_size * beam_size
    self.network.pre_create(stride)

    with tf.variable_scope(self.name, reuse=True):
      state.handle, feature_ids = extract_fixed_feature_ids(self, state, stride)

    with tf.variable_scope(self.name, reuse=True):
      tensors = self.network.create(
          feature_ids, [], None, None, during_training, stride=stride)
    update_network_states(self, tensors, network_states, stride)
    return state.handle


class BulkAnnotatorComponentBuilder(component.ComponentBuilderBase):
  """Builds a component that bulk-annotates, or scores a gold path, in bulk.

  Suitable for features that do not depend on transition-system state. No
  feature extraction is performed, so only non-recurrent 'identity' linked
  features are supported and fixed features are rejected.

  When configured as a FeedForwardNetwork with no hidden units, this
  component acts as a 'bulk softmax' component.
  """

  def _linked_embeddings_without_fixed(self, network_states):
    """Rejects fixed features, then fetches every linked embedding.

    Args:
      network_states: Dictionary of component NetworkState objects.

    Returns:
      List of NamedTensor linked embeddings, one per linked feature.

    Raises:
      RuntimeError: if any fixed features are configured.
    """
    if self.spec.fixed_feature:
      raise RuntimeError(
          'Fixed features are not compatible with bulk annotation. '
          'Use the "bulk-features" component instead.')
    return [
        fetch_linked_embedding(self, network_states, linked_spec)
        for linked_spec in self.spec.linked_feature
    ]

  def build_greedy_training(self, state, network_states):
    """Advances a batch along oracle paths and returns the overall cost.

    Args:
      state: MasterState from the 'AdvanceMaster' op that advances the
          underlying master to this component.
      network_states: Dictionary of component NetworkState objects.

    Returns:
      (state handle, cost, correct, total): TF ops for the final state after
          unrolling, the regularized cost, the number of correctly predicted
          actions, and the total number of actions. The cost comes from the
          network's compute_bulk_loss(), falling back to softmax cross
          entropy when that returns no cost.

    Raises:
      RuntimeError: if fixed features are configured.
    """
    logging.info('Building component: %s', self.spec.name)
    linked_embeddings = self._linked_embeddings_without_fixed(network_states)

    stride = state.current_batch_size * self.training_beam_size
    self.network.pre_create(stride)
    with tf.variable_scope(self.name, reuse=True):
      network_tensors = self.network.create([], linked_embeddings, None, None,
                                            True, stride)

    update_network_states(self, network_tensors, network_states, stride)

    state.handle, gold = dragnn_ops.bulk_advance_from_oracle(
        state.handle, component=self.name)

    bulk_loss = self.network.compute_bulk_loss(stride, network_tensors, gold)
    cost, correct, total = bulk_loss
    if cost is None:
      # No custom bulk loss on the network: fall back to softmax cross entropy.
      logits = self.network.get_logits(network_tensors)
      cost, correct, total = build_cross_entropy_loss(logits, gold)

    regularized_cost = self.add_regularizer(cost)
    return state.handle, regularized_cost, correct, total

  def build_greedy_inference(self, state, network_states,
                             during_training=False):
    """Annotates a batch of documents using the network's scores.

    Args:
      state: MasterState from the 'AdvanceMaster' op that advances the
          underlying master to this component.
      network_states: Dictionary of component NetworkState objects.
      during_training: Whether the graph is being constructed during training.

    Returns:
      Handle to the state once inference is complete for this Component.

    Raises:
      RuntimeError: if fixed features are configured.
    """
    logging.info('Building component: %s', self.spec.name)
    linked_embeddings = self._linked_embeddings_without_fixed(network_states)

    if during_training:
      beam_size = self.training_beam_size
    else:
      beam_size = self.inference_beam_size
    stride = state.current_batch_size * beam_size
    self.network.pre_create(stride)

    with tf.variable_scope(self.name, reuse=True):
      network_tensors = self.network.create([], linked_embeddings, None, None,
                                            during_training, stride)

    update_network_states(self, network_tensors, network_states, stride)

    predictions = self.network.get_bulk_predictions(stride, network_tensors)
    if predictions is None:
      # No custom bulk predictions on the network. Derive them from the
      # logits, log-normalized when |self.locally_normalize| holds and
      # optionally converted to probabilities.
      raw_logits = self.network.get_logits(network_tensors)
      predictions = tf.cond(self.locally_normalize,
                            lambda: tf.nn.log_softmax(raw_logits),
                            lambda: raw_logits)
      if self._output_as_probabilities:
        predictions = tf.nn.softmax(predictions)
    final_handle = dragnn_ops.bulk_advance_from_prediction(
        state.handle, predictions, component=self.name)

    self._add_runtime_hooks()
    return final_handle