# Copyright 2017 Google Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for the DRAGNN spec builder."""

import os.path
import tempfile

import tensorflow as tf

from google.protobuf import text_format
from dragnn.protos import spec_pb2
from dragnn.python import spec_builder

# Imported for FLAGS.tf_master, which is used in the lexicon module.
from syntaxnet import parser_trainer  # pylint: disable=unused-import

import syntaxnet.load_parser_ops

FLAGS = tf.app.flags.FLAGS
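# Define fallbacks for the standard test flags if the test runner didn't set
# them, so the test can also run outside the usual test environment.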
if not hasattr(FLAGS, 'test_srcdir'):
  FLAGS.test_srcdir = ''
if not hasattr(FLAGS, 'test_tmpdir'):
  FLAGS.test_tmpdir = tf.test.get_temp_dir()


class SpecBuilderTest(tf.test.TestCase):

  def assertSpecEqual(self, expected_spec_text, spec):
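    """Asserts that `spec` equals the ComponentSpec parsed from text."""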
    expected_spec = spec_pb2.ComponentSpec()
    text_format.Parse(expected_spec_text, expected_spec)
    self.assertProtoEquals(expected_spec, spec)

  def testComponentSpecBuilderEmpty(self):
    builder = spec_builder.ComponentSpecBuilder('test')
    self.assertSpecEqual(r"""
name: "test"
backend { registered_name: "SyntaxNetComponent" }
component_builder { registered_name: "DynamicComponentBuilder" }
        """, builder.spec)

  def testComponentSpecBuilderFixedFeature(self):
    builder = spec_builder.ComponentSpecBuilder('test')
    builder.set_network_unit('FeedForwardNetwork', hidden_layer_sizes='64,64')
    builder.set_transition_system('shift-only')
    builder.add_fixed_feature(name='words', fml='input.word', embedding_dim=16)
    self.assertSpecEqual(r"""
name: "test"
fixed_feature { name: "words" fml: "input.word" embedding_dim: 16 }
backend { registered_name: "SyntaxNetComponent" }
component_builder { registered_name: "DynamicComponentBuilder" }
network_unit { registered_name: "FeedForwardNetwork"
               parameters { key: "hidden_layer_sizes" value: "64,64" } }
transition_system { registered_name: "shift-only" }
        """, builder.spec)

  def testComponentSpecBuilderLinkedFeature(self):
    builder1 = spec_builder.ComponentSpecBuilder('test1')
    builder1.set_transition_system('shift-only')
    builder1.add_fixed_feature(name='words', fml='input.word', embedding_dim=16)
    builder2 = spec_builder.ComponentSpecBuilder('test2')
    builder2.set_network_unit('IdentityNetwork')
    builder2.set_transition_system('tagger')
    builder2.add_token_link(
        source=builder1,
        source_layer='words',
        fml='input.focus',
        embedding_dim=-1)
    self.assertSpecEqual(r"""
name: "test2"
linked_feature { name: "test1" source_component: "test1" source_layer: "words"
                 source_translator: "identity" fml: "input.focus"
                 embedding_dim: -1 }
backend { registered_name: "SyntaxNetComponent" }
component_builder { registered_name: "DynamicComponentBuilder" }
network_unit { registered_name: "IdentityNetwork" }
transition_system { registered_name: "tagger" }
        """, builder2.spec)

  def testFillsTaggerTransitions(self):
    lexicon_dir = tempfile.mkdtemp()

    def write_lines(filename, lines):
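      """Writes `lines`, one per line, to `filename` in the lexicon dir."""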
      with open(os.path.join(lexicon_dir, filename), 'w') as f:
        f.write(''.join('{}\n'.format(line) for line in lines))

    # The label map is required, even though it isn't used by this test.
    write_lines('label-map', ['0'])
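    # The word and tag maps list an entry count followed by "term frequency"
    # lines.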
    write_lines('word-map', ['2', 'miranda 1', 'rights 1'])
    write_lines('tag-map', ['2', 'NN 1', 'NNP 1'])
    write_lines('tag-to-category', ['NN\tNOUN', 'NNP\tNOUN'])

    tagger = spec_builder.ComponentSpecBuilder('tagger')
    tagger.set_network_unit(name='FeedForwardNetwork', hidden_layer_sizes='256')
    tagger.set_transition_system(name='tagger')
    tagger.add_fixed_feature(name='words', fml='input.word', embedding_dim=64)
    tagger.add_rnn_link(embedding_dim=-1)
    tagger.fill_from_resources(lexicon_dir)

    fixed_feature, = tagger.spec.fixed_feature
    linked_feature, = tagger.spec.linked_feature
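    # fill_from_resources() should size the spec from the lexicon: the word
    # vocabulary gains special entries beyond the two mapped words, and the
    # tagger transition system gets one action per tag (NN, NNP).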
    self.assertEqual(fixed_feature.vocabulary_size, 5)
    self.assertEqual(fixed_feature.size, 1)
    self.assertEqual(linked_feature.size, 1)
    self.assertEqual(tagger.spec.num_actions, 2)


if __name__ == '__main__':
  tf.test.main()