# Copyright 2018 Google Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""DRAGNN wrappers for the MST solver."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import tensorflow as tf

from dragnn.python import mst_ops
from dragnn.python import network_units
from syntaxnet.util import check


class MstSolverNetwork(network_units.NetworkUnitInterface):
  """Network unit that performs MST prediction with structured loss.

  Parameters:
    forest: If true, solve for a spanning forest instead of a spanning tree.
    loss: The loss function for training.  Select from
      softmax: Default unstructured softmax (prediction is still structured).
      m3n: Max-Margin Markov Networks structured hinge loss.
      crf: Negative log-likelihood of the gold tree under a CRF.
    crf_max_dynamic_range: Max dynamic range for computing the log partition
      function of the CRF loss.

  Links:
    lengths: [B, 1] sequence lengths per batch item.
    scores: [B * N, N] matrix of padded batched arc scores.

  Layers:
    lengths: [B] sequence lengths per batch item.
    scores: [B, N, N] tensor of padded batched arc scores.
    logits: [B * N, N] flattened matrix of padded batched arc scores.
    arcs: [B * N, N] matrix of padded batched 0/1 indicators for MST arcs.
  """

  def __init__(self, component):
    """Initializes layers.

    Args:
      component: Parent ComponentBuilderBase object.
    """
    layers = [
        network_units.Layer(self, 'lengths', -1),
        network_units.Layer(self, 'scores', -1),
        network_units.Layer(self, 'logits', -1),
        network_units.Layer(self, 'arcs', -1),
    ]
    super(MstSolverNetwork, self).__init__(component, init_layers=layers)

    self._attrs = network_units.get_attrs_with_defaults(
        component.spec.network_unit.parameters,
        defaults={
            'forest': False,
            'loss': 'softmax',
            'crf_max_dynamic_range': 20,
        })

    check.Eq(
        len(self._fixed_feature_dims.items()), 0, 'Expected no fixed features')
    check.Eq(
        len(self._linked_feature_dims.items()), 2,
        'Expected two linked features')

    check.In('lengths', self._linked_feature_dims,
             'Missing required linked feature')
    check.In('scores', self._linked_feature_dims,
             'Missing required linked feature')

  def create(self,
             fixed_embeddings,
             linked_embeddings,
             context_tensor_arrays,
             attention_tensor,
             during_training,
             stride=None):
    """Forwards the lengths and scores."""
    check.NotNone(stride, 'MstSolverNetwork requires stride')

    lengths = network_units.lookup_named_tensor('lengths', linked_embeddings)
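    # The 'lengths' link is a [B, 1] float tensor; squeeze and cast it to an
    # int32 vector of per-sentence lengths.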
    lengths_b = tf.to_int32(tf.squeeze(lengths.tensor, [1]))

    scores = network_units.lookup_named_tensor('scores', linked_embeddings)
    scores_bnxn = scores.tensor
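    # The 'scores' link is a flat [B * N, N] arc-score matrix, where N is the
    # padded sequence length; recover the batched [B, N, N] shape via stride.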
    max_length = tf.shape(scores_bnxn)[1]
    scores_bxnxn = tf.reshape(scores_bnxn, [stride, max_length, max_length])

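    # Decode the maximum spanning tree (or forest, per 'forest' attribute).
    # argmax_sources_bxn[b, t] is the selected source (head) of token t; turn
    # it into a flat [B * N, N] one-hot arc-indicator matrix ('arcs' layer).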
    _, argmax_sources_bxn = mst_ops.maximum_spanning_tree(
        forest=self._attrs['forest'], num_nodes=lengths_b, scores=scores_bxnxn)
    argmax_sources_bn = tf.reshape(argmax_sources_bxn, [-1])
    arcs_bnxn = tf.one_hot(argmax_sources_bn, max_length, dtype=tf.float32)

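    # Return tensors in the order of the declared layers: 'lengths', 'scores',
    # 'logits', 'arcs'; the flat [B * N, N] scores double as the 'logits'.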
    return [lengths_b, scores_bxnxn, scores_bnxn, arcs_bnxn]

  def get_logits(self, network_tensors):
    return network_tensors[self.get_layer_index('logits')]

  def get_bulk_predictions(self, stride, network_tensors):
    return network_tensors[self.get_layer_index('arcs')]

  def compute_bulk_loss(self, stride, network_tensors, gold):
    """See base class."""
    if self._attrs['loss'] == 'softmax':
      return (None, None, None)  # fall back to default bulk softmax

    lengths_b, scores_bxnxn, _, arcs_bnxn = network_tensors
    max_length = tf.shape(scores_bxnxn)[2]
    arcs_bxnxn = tf.reshape(arcs_bnxn, [stride, max_length, max_length])
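    # |gold| holds flat [B * N] gold head indices; expand it to a [B, N, N]
    # one-hot arc-indicator tensor aligned with the batched scores.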
    gold_bxn = tf.reshape(gold, [stride, max_length])
    gold_bxnxn = tf.one_hot(gold_bxn, max_length, dtype=tf.float32)

    loss = self._compute_loss(lengths_b, scores_bxnxn, gold_bxnxn)
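    # Count correct arcs as the overlap between predicted and gold indicators;
    # |total| is the number of real (non-padding) tokens in the batch.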
    correct = tf.reduce_sum(tf.to_int32(arcs_bxnxn * gold_bxnxn))
    total = tf.reduce_sum(lengths_b)
    return loss, correct, total

  def _compute_loss(self, lengths, scores, gold):
    """Computes the configured structured loss for a batch.

    Args:
      lengths: [B] sequence lengths per batch item.
      scores: [B, N, N] tensor of padded batched arc scores.
      gold: [B, N, N] tensor of 0/1 indicators for gold arcs.

    Returns:
      Scalar sum of losses across the batch.
    """
    # Dispatch to one of the _compute_*_loss() methods.
    method_name = '_compute_%s_loss' % self._attrs['loss']
    loss_b = getattr(self, method_name)(lengths, scores, gold)
    return tf.reduce_sum(loss_b)

  def _compute_m3n_loss(self, lengths, scores, gold):
    """Computes the M3N-style structured hinge loss for a batch."""
    # Perform Hamming-loss-augmented inference.
    gold_scores_b = tf.reduce_sum(scores * gold, axis=[1, 2])
    hamming_loss_bxnxn = 1 - gold
    scores_bxnxn = scores + hamming_loss_bxnxn
    max_scores_b, _ = mst_ops.maximum_spanning_tree(
        num_nodes=lengths, scores=scores_bxnxn, forest=self._attrs['forest'])
    return max_scores_b - gold_scores_b

  def _compute_crf_loss(self, lengths, scores, gold):
    """Computes the negative CRF log-probability for a batch."""
    # The |scores| are assumed to be in the log domain.
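    # The loss is the negative log-likelihood of the gold structure:
    #   -log P(gold) = log Z(scores) - score(gold),
    # where log Z sums exp(score(y)) over all trees (or forests) y.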
    log_gold_scores_b = tf.reduce_sum(scores * gold, axis=[1, 2])
    log_partition_functions_b = mst_ops.log_partition_function(
        num_nodes=lengths,
        scores=scores,
        forest=self._attrs['forest'],
        max_dynamic_range=self._attrs['crf_max_dynamic_range'])
    return log_partition_functions_b - log_gold_scores_b  # negative log-prob
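

# A minimal usage sketch of the underlying solver op (illustration only, not
# part of the original module; assumes the dragnn mst_ops kernels are built):
#
#   lengths = tf.constant([3], dtype=tf.int32)    # one 3-token batch item
#   scores_bxnxn = tf.random_uniform([1, 3, 3])   # [B, N, N] arc scores
#   max_scores_b, heads_bxn = mst_ops.maximum_spanning_tree(
#       num_nodes=lengths, scores=scores_bxnxn, forest=False)
#   # heads_bxn[0, t] is the selected source (head) of token t.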