Commit ea3fa4a3 authored by Ivan Bogatyy's avatar Ivan Bogatyy
Browse files

Update DRAGNN, fix some macOS issues

parent b7523ee5
# Copyright 2017 Google Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Component builders for non-recurrent networks in DRAGNN."""
......@@ -249,7 +264,8 @@ class BulkFeatureExtractorComponentBuilder(component.ComponentBuilderBase):
update_network_states(self, tensors, network_states, stride)
cost = self.add_regularizer(tf.constant(0.))
return state.handle, cost, 0, 0
correct, total = tf.constant(0), tf.constant(0)
return state.handle, cost, correct, total
def build_greedy_inference(self, state, network_states,
during_training=False):
......@@ -327,7 +343,8 @@ class BulkFeatureIdExtractorComponentBuilder(component.ComponentBuilderBase):
"""See base class."""
state.handle = self._extract_feature_ids(state, network_states, True)
cost = self.add_regularizer(tf.constant(0.))
return state.handle, cost, 0, 0
correct, total = tf.constant(0), tf.constant(0)
return state.handle, cost, correct, total
def build_greedy_inference(self, state, network_states,
during_training=False):
......
# Copyright 2017 Google Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for bulk_component.
Verifies that:
......
# Copyright 2017 Google Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Builds a DRAGNN graph for local training."""
from abc import ABCMeta
......@@ -147,6 +162,32 @@ class ComponentBuilderBase(object):
"""
pass
def build_structured_training(self, state, network_states):
  """Builds a beam search based training loop for this component.

  The default implementation builds a dummy graph and raises a
  TensorFlow runtime exception to indicate that structured training
  is not implemented.

  Args:
    state: MasterState from the 'AdvanceMaster' op that advances the
      underlying master to this component.
    network_states: dictionary of component NetworkState objects.

  Returns:
    (handle, cost, correct, total) -- These are TF ops corresponding
    to the final handle after unrolling, the total cost, the number of
    correctly predicted actions, and the total number of actions. Since
    correct/total counts are not applicable in the structured training
    setting, dummy zero-valued constants are returned for them here.
  """
  del network_states  # Unused.
  # tf.Assert(False, ...) always fails when executed; routing the state
  # handle through its control dependency guarantees that any session run
  # touching this component's training subgraph raises at runtime (rather
  # than at graph-construction time).
  with tf.control_dependencies([tf.Assert(False, ['Not implemented.'])]):
    handle = tf.identity(state.handle)
  cost = tf.constant(0.)
  correct, total = tf.constant(0), tf.constant(0)
  return handle, cost, correct, total
@abstractmethod
def build_greedy_inference(self, state, network_states,
during_training=False):
......@@ -349,14 +390,13 @@ class DynamicComponentBuilder(ComponentBuilderBase):
correctly predicted actions, and the total number of actions.
"""
logging.info('Building component: %s', self.spec.name)
stride = state.current_batch_size * self.training_beam_size
with tf.control_dependencies([tf.assert_equal(self.training_beam_size, 1)]):
stride = state.current_batch_size * self.training_beam_size
cost = tf.constant(0.)
correct = tf.constant(0)
total = tf.constant(0)
# Create the TensorArray's to store activations for downstream/recurrent
# connections.
def cond(handle, *_):
all_final = dragnn_ops.emit_all_final(handle, component=self.name)
return tf.logical_not(tf.reduce_all(all_final))
......
# Copyright 2017 Google Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""An optimizer that switches between several methods."""
import tensorflow as tf
......
"""Tests for CompositeOptimizer.
"""
# Copyright 2017 Google Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for CompositeOptimizer."""
import numpy as np
......
# Copyright 2017 Google Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""TensorFlow ops for directed graphs."""
import tensorflow as tf
......
# Copyright 2017 Google Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for digraph ops."""
import tensorflow as tf
......
# Copyright 2017 Google Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Groups the DRAGNN TensorFlow ops in one module."""
......
# Copyright 2017 Google Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Builds a DRAGNN graph for local training."""
......@@ -65,6 +80,13 @@ def _create_optimizer(hyperparams, learning_rate_var, step_var=None):
beta2=hyperparams.adam_beta2,
epsilon=hyperparams.adam_eps,
use_locking=True)
elif hyperparams.learning_method == 'lazyadam':
return tf.contrib.opt.LazyAdamOptimizer(
learning_rate_var,
beta1=hyperparams.adam_beta1,
beta2=hyperparams.adam_beta2,
epsilon=hyperparams.adam_eps,
use_locking=True)
elif hyperparams.learning_method == 'momentum':
return tf.train.MomentumOptimizer(
learning_rate_var, hyperparams.momentum, use_locking=True)
......@@ -138,6 +160,10 @@ class MasterBuilder(object):
if hyperparam_config is None else hyperparam_config)
self.pool_scope = pool_scope
# Set the graph-level random seed before creating the Components so the ops
# they create will use this seed.
tf.set_random_seed(hyperparam_config.seed)
# Construct all utility class and variables for each Component.
self.components = []
self.lookup_component = {}
......@@ -318,15 +344,18 @@ class MasterBuilder(object):
dragnn_ops.batch_size(
handle, component=comp.name))
with tf.control_dependencies([handle, cost]):
component_cost = tf.constant(0.)
component_correct = tf.constant(0)
component_total = tf.constant(0)
args = (master_state, network_states)
if unroll_using_oracle[component_index]:
handle, component_cost, component_correct, component_total = (
comp.build_greedy_training(master_state, network_states))
handle, component_cost, component_correct, component_total = (tf.cond(
comp.training_beam_size > 1,
lambda: comp.build_structured_training(*args),
lambda: comp.build_greedy_training(*args)))
else:
handle = comp.build_greedy_inference(
master_state, network_states, during_training=True)
handle = comp.build_greedy_inference(*args, during_training=True)
component_cost = tf.constant(0.)
component_correct, component_total = tf.constant(0), tf.constant(0)
weighted_component_cost = tf.multiply(
component_cost,
......@@ -497,30 +526,23 @@ class MasterBuilder(object):
with tf.name_scope(scope_id):
# Construct training targets. Disable tracing during training.
handle, input_batch = self._get_session_with_reader(trace_only)
# If `trace_only` is True, the training graph shouldn't have any
# side effects. Otherwise, the standard training scenario should
# generate gradients and update counters.
handle, outputs = self.build_training(
handle,
compute_gradients=not trace_only,
advance_counters=not trace_only,
component_weights=target_config.component_weights,
unroll_using_oracle=target_config.unroll_using_oracle,
max_index=target_config.max_index,
**kwargs)
if trace_only:
# Build a training graph that doesn't have any side effects.
handle, outputs = self.build_training(
handle,
compute_gradients=False,
advance_counters=False,
component_weights=target_config.component_weights,
unroll_using_oracle=target_config.unroll_using_oracle,
max_index=target_config.max_index,
**kwargs)
outputs['traces'] = dragnn_ops.get_component_trace(
handle, component=self.spec.component[-1].name)
else:
# The standard training scenario has gradients and updates counters.
handle, outputs = self.build_training(
handle,
compute_gradients=True,
advance_counters=True,
component_weights=target_config.component_weights,
unroll_using_oracle=target_config.unroll_using_oracle,
max_index=target_config.max_index,
**kwargs)
# In addition, it keeps track of the number of training steps.
# Standard training keeps track of the number of training steps.
outputs['target_step'] = tf.get_variable(
scope_id + '/TargetStep', [],
initializer=tf.zeros_initializer(),
......
# Copyright 2017 Google Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for graph_builder."""
......@@ -517,6 +532,23 @@ class GraphBuilderTest(test_util.TensorFlowTestCase):
expected_num_actions=9,
expected=_TAGGER_PARSER_EXPECTED_SENTENCES)
def testStructuredTrainingNotImplementedDeath(self):
  """Checks that beam training fails at runtime, not at graph-build time."""
  spec = self.LoadSpec('simple_parser_master_spec.textproto')

  # Make the 'parser' component have a beam at training time.
  self.assertEqual('parser', spec.component[0].name)
  spec.component[0].training_beam_size = 8

  # The training run should fail at runtime rather than build time; the
  # expected message matches the '[Not implemented.]' assertion emitted by
  # the default build_structured_training implementation.
  with self.assertRaisesRegexp(tf.errors.InvalidArgumentError,
                               r'\[Not implemented.\]'):
    self.RunFullTrainingAndInference(
        'simple-parser',
        master_spec=spec,
        expected_num_actions=8,
        component_weights=[1],
        expected=_LABELED_PARSER_EXPECTED_SENTENCES)
def testSimpleParser(self):
self.RunFullTrainingAndInference(
'simple-parser',
......
# Copyright 2017 Google Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Basic network units used in assembling DRAGNN graphs."""
from abc import ABCMeta
......@@ -88,7 +103,7 @@ class NamedTensor(object):
self.dim = dim
def add_embeddings(channel_id, feature_spec, seed):
def add_embeddings(channel_id, feature_spec, seed=None):
"""Adds a variable for the embedding of a given fixed feature.
Supports pre-trained or randomly initialized embeddings. In both cases, extra
......@@ -119,11 +134,14 @@ def add_embeddings(channel_id, feature_spec, seed):
if len(feature_spec.vocab.part) > 1:
raise RuntimeError('vocab resource contains more than one part:\n%s',
str(feature_spec.vocab))
seed1, seed2 = tf.get_seed(seed)
embeddings = dragnn_ops.dragnn_embedding_initializer(
embedding_input=feature_spec.pretrained_embedding_matrix.part[0]
.file_pattern,
vocab=feature_spec.vocab.part[0].file_pattern,
scaling_coefficient=1.0)
scaling_coefficient=1.0,
seed=seed1,
seed2=seed2)
return tf.get_variable(name, initializer=tf.reshape(embeddings, shape))
else:
return tf.get_variable(
......@@ -622,7 +640,6 @@ class NetworkUnitInterface(object):
init_layers: optional initial layers.
init_context_layers: optional initial context layers.
"""
self._seed = component.master.hyperparams.seed
self._component = component
self._params = []
self._layers = init_layers if init_layers else []
......@@ -640,7 +657,7 @@ class NetworkUnitInterface(object):
check.Gt(spec.size, 0, 'Invalid fixed feature size')
if spec.embedding_dim > 0:
fixed_dim = spec.embedding_dim
self._params.append(add_embeddings(channel_id, spec, self._seed))
self._params.append(add_embeddings(channel_id, spec))
else:
fixed_dim = 1 # assume feature ID extraction; only one ID per step
self._fixed_feature_dims[spec.name] = spec.size * fixed_dim
......@@ -663,7 +680,7 @@ class NetworkUnitInterface(object):
linked_embeddings_name(channel_id),
[source_array_dim + 1, spec.embedding_dim],
initializer=tf.random_normal_initializer(
stddev=1 / spec.embedding_dim**.5, seed=self._seed)))
stddev=1 / spec.embedding_dim**.5)))
self._linked_feature_dims[spec.name] = spec.size * spec.embedding_dim
else:
......@@ -698,14 +715,12 @@ class NetworkUnitInterface(object):
tf.get_variable(
'attention_weights_pm_0',
[attention_hidden_layer_size, hidden_layer_size],
initializer=tf.random_normal_initializer(
stddev=1e-4, seed=self._seed)))
initializer=tf.random_normal_initializer(stddev=1e-4)))
self._params.append(
tf.get_variable(
'attention_weights_hm_0', [hidden_layer_size, hidden_layer_size],
initializer=tf.random_normal_initializer(
stddev=1e-4, seed=self._seed)))
initializer=tf.random_normal_initializer(stddev=1e-4)))
self._params.append(
tf.get_variable(
......@@ -721,8 +736,7 @@ class NetworkUnitInterface(object):
tf.get_variable(
'attention_weights_pu',
[attention_hidden_layer_size, component.num_actions],
initializer=tf.random_normal_initializer(
stddev=1e-4, seed=self._seed)))
initializer=tf.random_normal_initializer(stddev=1e-4)))
@abstractmethod
def create(self,
......@@ -961,8 +975,7 @@ class FeedForwardNetwork(NetworkUnitInterface):
for index, hidden_layer_size in enumerate(self._hidden_layer_sizes):
weights = tf.get_variable(
'weights_%d' % index, [last_layer_dim, hidden_layer_size],
initializer=tf.random_normal_initializer(stddev=1e-4,
seed=self._seed))
initializer=tf.random_normal_initializer(stddev=1e-4))
self._params.append(weights)
if index > 0 or self._layer_norm_hidden is None:
self._params.append(
......@@ -988,8 +1001,7 @@ class FeedForwardNetwork(NetworkUnitInterface):
self._params.append(
tf.get_variable(
'weights_softmax', [last_layer_dim, component.num_actions],
initializer=tf.random_normal_initializer(
stddev=1e-4, seed=self._seed)))
initializer=tf.random_normal_initializer(stddev=1e-4)))
self._params.append(
tf.get_variable(
'bias_softmax', [component.num_actions],
......@@ -1106,47 +1118,39 @@ class LSTMNetwork(NetworkUnitInterface):
# e.g. truncated_normal_initializer?
self._x2i = tf.get_variable(
'x2i', [layer_input_dim, self._hidden_layer_sizes],
initializer=tf.random_normal_initializer(stddev=1e-4,
seed=self._seed))
initializer=tf.random_normal_initializer(stddev=1e-4))
self._h2i = tf.get_variable(
'h2i', [self._hidden_layer_sizes, self._hidden_layer_sizes],
initializer=tf.random_normal_initializer(stddev=1e-4,
seed=self._seed))
initializer=tf.random_normal_initializer(stddev=1e-4))
self._c2i = tf.get_variable(
'c2i', [self._hidden_layer_sizes, self._hidden_layer_sizes],
initializer=tf.random_normal_initializer(stddev=1e-4,
seed=self._seed))
initializer=tf.random_normal_initializer(stddev=1e-4))
self._bi = tf.get_variable(
'bi', [self._hidden_layer_sizes],
initializer=tf.random_normal_initializer(stddev=1e-4, seed=self._seed))
initializer=tf.random_normal_initializer(stddev=1e-4))
self._x2o = tf.get_variable(
'x2o', [layer_input_dim, self._hidden_layer_sizes],
initializer=tf.random_normal_initializer(stddev=1e-4,
seed=self._seed))
initializer=tf.random_normal_initializer(stddev=1e-4))
self._h2o = tf.get_variable(
'h2o', [self._hidden_layer_sizes, self._hidden_layer_sizes],
initializer=tf.random_normal_initializer(stddev=1e-4,
seed=self._seed))
initializer=tf.random_normal_initializer(stddev=1e-4))
self._c2o = tf.get_variable(
'c2o', [self._hidden_layer_sizes, self._hidden_layer_sizes],
initializer=tf.random_normal_initializer(stddev=1e-4,
seed=self._seed))
initializer=tf.random_normal_initializer(stddev=1e-4))
self._bo = tf.get_variable(
'bo', [self._hidden_layer_sizes],
initializer=tf.random_normal_initializer(stddev=1e-4, seed=self._seed))
initializer=tf.random_normal_initializer(stddev=1e-4))
self._x2c = tf.get_variable(
'x2c', [layer_input_dim, self._hidden_layer_sizes],
initializer=tf.random_normal_initializer(stddev=1e-4,
seed=self._seed))
initializer=tf.random_normal_initializer(stddev=1e-4))
self._h2c = tf.get_variable(
'h2c', [self._hidden_layer_sizes, self._hidden_layer_sizes],
initializer=tf.random_normal_initializer(stddev=1e-4,
seed=self._seed))
initializer=tf.random_normal_initializer(stddev=1e-4))
self._bc = tf.get_variable(
'bc', [self._hidden_layer_sizes],
initializer=tf.random_normal_initializer(stddev=1e-4, seed=self._seed))
initializer=tf.random_normal_initializer(stddev=1e-4))
self._params.extend([
self._x2i, self._h2i, self._c2i, self._bi, self._x2o, self._h2o,
......@@ -1166,8 +1170,7 @@ class LSTMNetwork(NetworkUnitInterface):
self.params.append(tf.get_variable(
'weights_softmax', [self._hidden_layer_sizes, component.num_actions],
initializer=tf.random_normal_initializer(stddev=1e-4,
seed=self._seed)))
initializer=tf.random_normal_initializer(stddev=1e-4)))
self.params.append(
tf.get_variable(
'bias_softmax', [component.num_actions],
......@@ -1324,8 +1327,7 @@ class ConvNetwork(NetworkUnitInterface):
tf.get_variable(
'weights',
self.kernel_shapes[i],
initializer=tf.random_normal_initializer(
stddev=1e-4, seed=self._seed),
initializer=tf.random_normal_initializer(stddev=1e-4),
dtype=tf.float32))
bias_init = 0.0 if (i == len(self._widths) - 1) else 0.2
self._biases.append(
......@@ -1473,8 +1475,7 @@ class PairwiseConvNetwork(NetworkUnitInterface):
tf.get_variable(
'weights',
kernel_shape,
initializer=tf.random_normal_initializer(
stddev=1e-4, seed=self._seed),
initializer=tf.random_normal_initializer(stddev=1e-4),
dtype=tf.float32))
bias_init = 0.0 if i in self._relu_layers else 0.2
self._biases.append(
......
# Copyright 2017 Google Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for network_units."""
......
# -*- coding: utf-8 -*-
# Copyright 2017 Google Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Renders parse trees with Graphviz."""
from __future__ import absolute_import
from __future__ import division
......
# Copyright 2017 Google Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for ....dragnn.python.render_parse_tree_graphviz."""
from __future__ import absolute_import
......
# -*- coding: utf-8 -*-
# Copyright 2017 Google Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Renders DRAGNN specs with Graphviz."""
from __future__ import absolute_import
from __future__ import division
......
# Copyright 2017 Google Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for render_spec_with_graphviz."""
from __future__ import absolute_import
......
# Copyright 2017 Google Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Utilities for reading and writing sentences in dragnn."""
import tensorflow as tf
from syntaxnet.ops import gen_parser_ops
......
# Copyright 2017 Google Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
import os
import tensorflow as tf
......
......@@ -36,16 +36,20 @@ class ComponentSpecBuilder(object):
spec: The dragnn.ComponentSpec proto.
"""
def __init__(self, name, builder='DynamicComponentBuilder'):
def __init__(self,
             name,
             builder='DynamicComponentBuilder',
             backend='SyntaxNetComponent'):
  """Initializes the ComponentSpec with some defaults for SyntaxNet.

  Args:
    name: The name of this Component in the pipeline.
    builder: The component builder type.
    backend: The component backend type; defaults to the SyntaxNet backend
      for backward compatibility with the previous hard-coded value.
  """
  # Fix: the diff residue carried both the old hard-coded
  # backend=self.make_module('SyntaxNetComponent') line and the new
  # parameterized one, which is a duplicate-keyword SyntaxError. Keep only
  # the parameterized form.
  self.spec = spec_pb2.ComponentSpec(
      name=name,
      backend=self.make_module(backend),
      component_builder=self.make_module(builder))
def make_module(self, name, **kwargs):
......
# Copyright 2017 Google Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Utility functions to build DRAGNN MasterSpecs and schedule model training.
Provides functions to finish a MasterSpec, building required lexicons for it and
......@@ -27,7 +42,7 @@ def calculate_component_accuracies(eval_res_values):
]
def write_summary(summary_writer, label, value, step):
  """Writes a summary for a certain evaluation.

  Fix: the diff residue left the stale pre-rename header
  `def _write_summary(...)` directly above the renamed function; only the
  renamed public function is kept.

  Args:
    summary_writer: summary writer (tf.summary.FileWriter-compatible) that
      receives the event via add_summary.
    label: tag string under which the value is recorded.
    value: numeric value to record; coerced to float.
    step: global step to associate with the summary.
  """
  summary = Summary(value=[Summary.Value(tag=label, simple_value=float(value))])
  summary_writer.add_summary(summary, step)
......@@ -135,7 +150,7 @@ def run_training(sess, trainers, annotator, evaluator, pretrain_steps,
annotated = annotate_dataset(sess, annotator, eval_corpus)
summaries = evaluator(eval_gold, annotated)
for label, metric in summaries.iteritems():
_write_summary(summary_writer, label, metric, actual_step + step)
write_summary(summary_writer, label, metric, actual_step + step)
eval_metric = summaries['eval_metric']
if best_eval_metric < eval_metric:
tf.logging.info('Updating best eval to %.2f%%, saving checkpoint.',
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment