Commit 4364390a authored by Ivan Bogatyy, committed by calberti

Release DRAGNN bulk networks (#2785)

* Release DRAGNN bulk networks
parent 638fd759
@@ -55,9 +55,6 @@ from dragnn.python import trainer_lib
 from syntaxnet.ops import gen_parser_ops
 from syntaxnet.util import check
-import dragnn.python.load_dragnn_cc_impl
-import syntaxnet.load_parser_ops
 flags = tf.app.flags
 FLAGS = flags.FLAGS
...
@@ -56,7 +56,7 @@ setuptools.setup(
     version='0.2',
     description='SyntaxNet: Neural Models of Syntax',
     long_description='',
-    url='https://github.com/tensorflow/models/tree/master/research/syntaxnet',
+    url='https://github.com/tensorflow/models/tree/master/syntaxnet',
     author='Google Inc.',
     author_email='opensource@google.com',
...
@@ -33,9 +33,6 @@ from syntaxnet import sentence_pb2
 from syntaxnet.ops import gen_parser_ops
 from syntaxnet.util import check
-import dragnn.python.load_dragnn_cc_impl
-import syntaxnet.load_parser_ops
 flags = tf.app.flags
 FLAGS = flags.FLAGS
...
# Copyright 2017 Google Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
r"""Runs a both a segmentation and parsing model on a CoNLL dataset.
"""
import re
import time
import tensorflow as tf
from tensorflow.python.client import timeline
from tensorflow.python.platform import gfile
# The following line is necessary to load custom ops into the library.
from dragnn.python import dragnn_ops
from dragnn.python import evaluation
from dragnn.python import sentence_io
from syntaxnet import sentence_pb2
# The following line is necessary to load custom ops into the library.
from syntaxnet import syntaxnet_ops
from syntaxnet.ops import gen_parser_ops
from syntaxnet.util import check
flags = tf.app.flags
FLAGS = flags.FLAGS
flags.DEFINE_string(
'segmenter_saved_model', None,
'Path to segmenter saved model. If not provided, gold segmentation is used.'
)
flags.DEFINE_string('parser_saved_model', None, 'Path to parser saved model.')
flags.DEFINE_string('input_file', '',
'File of CoNLL-formatted sentences to read from.')
flags.DEFINE_string('output_file', '',
'File path to write annotated sentences to.')
flags.DEFINE_bool('text_format', False,
                  'If true, read input as untokenized text instead of CoNLL.')
flags.DEFINE_integer('max_batch_size', 2048, 'Maximum batch size to support.')
flags.DEFINE_string('inference_beam_size', '', 'Comma separated list of '
'component_name=beam_size pairs.')
flags.DEFINE_string('locally_normalize', '', 'Comma separated list of '
'component names to do local normalization on.')
flags.DEFINE_integer('threads', 10, 'Number of threads used for intra- and '
'inter-op parallelism.')
flags.DEFINE_string('timeline_output_file', '', 'Path to save timeline to. '
'If specified, the final iteration of the evaluation loop '
'will capture and save a TensorFlow timeline.')
def get_segmenter_corpus(input_data_path, use_text_format):
"""Reads in a character corpus for segmenting."""
# Read in the documents.
tf.logging.info('Reading documents...')
if use_text_format:
char_corpus = sentence_io.FormatSentenceReader(input_data_path,
'untokenized-text').corpus()
else:
input_corpus = sentence_io.ConllSentenceReader(input_data_path).corpus()
with tf.Session(graph=tf.Graph()) as tmp_session:
char_input = gen_parser_ops.char_token_generator(input_corpus)
char_corpus = tmp_session.run(char_input)
check.Eq(len(input_corpus), len(char_corpus))
return char_corpus
def run_segmenter(input_data, segmenter_model, session_config, max_batch_size,
timeline_output_file=None):
"""Runs the provided segmenter model on the provided character corpus.
Args:
input_data: Character input corpus to segment.
segmenter_model: Path to a SavedModel file containing the segmenter graph.
session_config: A session configuration object.
max_batch_size: The maximum batch size to use.
timeline_output_file: Filepath for timeline export. Does not export if None.
Returns:
A list of segmented sentences suitable for parsing.
"""
# Create the session and graph, and load the SavedModel.
g = tf.Graph()
with tf.Session(graph=g, config=session_config) as sess:
tf.logging.info('Initializing segmentation model...')
tf.saved_model.loader.load(sess, [tf.saved_model.tag_constants.SERVING],
segmenter_model)
# Use the graph to segment the sentences.
tf.logging.info('Segmenting sentences...')
processed = []
start_time = time.time()
run_metadata = tf.RunMetadata()
for start in range(0, len(input_data), max_batch_size):
# Prepare the inputs.
end = min(start + max_batch_size, len(input_data))
feed_dict = {
'annotation/ComputeSession/InputBatch:0': input_data[start:end]
}
output_node = 'annotation/annotations:0'
# Process.
tf.logging.info('Processing examples %d to %d' % (start, end))
if timeline_output_file and end == len(input_data):
serialized_annotations = sess.run(
output_node,
feed_dict=feed_dict,
options=tf.RunOptions(trace_level=tf.RunOptions.FULL_TRACE),
run_metadata=run_metadata)
trace = timeline.Timeline(step_stats=run_metadata.step_stats)
with open(timeline_output_file, 'w') as trace_file:
trace_file.write(trace.generate_chrome_trace_format())
else:
serialized_annotations = sess.run(output_node, feed_dict=feed_dict)
# Save the outputs.
processed.extend(serialized_annotations)
# Report statistics.
tf.logging.info('Segmented %d documents in %.2f seconds.',
len(input_data), time.time() - start_time)
# Once all sentences are segmented, the processed data can be used in the
# parsers.
return processed
def run_parser(input_data, parser_model, session_config, beam_sizes,
locally_normalized_components, max_batch_size,
timeline_output_file):
"""Runs the provided segmenter model on the provided character corpus.
Args:
input_data: Input corpus to parse.
parser_model: Path to a SavedModel file containing the parser graph.
session_config: A session configuration object.
beam_sizes: Iterable of (component name, beam size) pairs (optional).
locally_normalized_components: A list of components to normalize (optional).
max_batch_size: The maximum batch size to use.
timeline_output_file: Filepath for timeline export. Does not export if None.
Returns:
A list of parsed sentences.
"""
parser_graph = tf.Graph()
with tf.Session(graph=parser_graph, config=session_config) as sess:
tf.logging.info('Initializing parser model...')
tf.saved_model.loader.load(sess, [tf.saved_model.tag_constants.SERVING],
parser_model)
tf.logging.info('Parsing sentences...')
processed = []
start_time = time.time()
run_metadata = tf.RunMetadata()
tf.logging.info('Corpus length is %d' % len(input_data))
for start in range(0, len(input_data), max_batch_size):
# Set up the input and output.
end = min(start + max_batch_size, len(input_data))
feed_dict = {
'annotation/ComputeSession/InputBatch:0': input_data[start:end]
}
for comp, beam_size in beam_sizes:
feed_dict['%s/InferenceBeamSize:0' % comp] = beam_size
for comp in locally_normalized_components:
feed_dict['%s/LocallyNormalize:0' % comp] = True
output_node = 'annotation/annotations:0'
# Process.
tf.logging.info('Processing examples %d to %d' % (start, end))
if timeline_output_file and end == len(input_data):
serialized_annotations = sess.run(
output_node,
feed_dict=feed_dict,
options=tf.RunOptions(trace_level=tf.RunOptions.FULL_TRACE),
run_metadata=run_metadata)
trace = timeline.Timeline(step_stats=run_metadata.step_stats)
with open(timeline_output_file, 'w') as trace_file:
trace_file.write(trace.generate_chrome_trace_format())
else:
serialized_annotations = sess.run(output_node, feed_dict=feed_dict)
processed.extend(serialized_annotations)
tf.logging.info('Processed %d documents in %.2f seconds.',
len(input_data), time.time() - start_time)
_, uas, las = evaluation.calculate_parse_metrics(input_data, processed)
tf.logging.info('UAS: %.2f', uas)
tf.logging.info('LAS: %.2f', las)
return processed
def print_output(output_file, use_text_format, use_gold_segmentation, output):
"""Writes a set of sentences in CoNLL format.
Args:
output_file: The file to write to.
use_text_format: Whether this computation used text-format input.
use_gold_segmentation: Whether this computation used gold segmentation.
output: A list of sentences to write to the output file.
"""
with gfile.GFile(output_file, 'w') as f:
f.write('## tf:{}\n'.format(use_text_format))
f.write('## gs:{}\n'.format(use_gold_segmentation))
for serialized_sentence in output:
sentence = sentence_pb2.Sentence()
sentence.ParseFromString(serialized_sentence)
f.write('# text = {}\n'.format(sentence.text.encode('utf-8')))
for i, token in enumerate(sentence.token):
head = token.head + 1
f.write('%s\t%s\t_\t_\t_\t_\t%d\t%s\t_\t_\n' %
(i + 1, token.word.encode('utf-8'), head,
token.label.encode('utf-8')))
f.write('\n')
def main(unused_argv):
# Validate that we have a parser saved model passed to this script.
if FLAGS.parser_saved_model is None:
tf.logging.fatal('A parser saved model must be provided.')
# Parse the flags containing lists, using regular expressions.
# This matches and extracts key=value pairs.
component_beam_sizes = re.findall(r'([^=,]+)=(\d+)',
FLAGS.inference_beam_size)
tf.logging.info('Found beam size dict %s' % component_beam_sizes)
# This matches strings separated by a comma. Does not return any empty
# strings.
components_to_locally_normalize = re.findall(r'[^,]+',
FLAGS.locally_normalize)
tf.logging.info(
'Found local normalization dict %s' % components_to_locally_normalize)
# Create a session config with the requested number of threads.
session_config = tf.ConfigProto(
log_device_placement=False,
intra_op_parallelism_threads=FLAGS.threads,
inter_op_parallelism_threads=FLAGS.threads)
# Get the segmented input data for the parser, either by running the
# segmenter ourselves or by simply reading it from the CoNLL file.
if FLAGS.segmenter_saved_model is None:
# If no segmenter was provided, we must use the data from the CONLL file.
input_file = FLAGS.input_file
parser_input = sentence_io.ConllSentenceReader(input_file).corpus()
use_gold_segmentation = True
else:
# If the segmenter was provided, use it.
segmenter_input = get_segmenter_corpus(FLAGS.input_file, FLAGS.text_format)
parser_input = run_segmenter(segmenter_input, FLAGS.segmenter_saved_model,
session_config, FLAGS.max_batch_size,
FLAGS.timeline_output_file)
use_gold_segmentation = False
# Now that we have parser input data, parse.
processed = run_parser(parser_input, FLAGS.parser_saved_model, session_config,
component_beam_sizes, components_to_locally_normalize,
FLAGS.max_batch_size, FLAGS.timeline_output_file)
if FLAGS.output_file:
print_output(FLAGS.output_file, FLAGS.text_format, use_gold_segmentation,
processed)
if __name__ == '__main__':
tf.app.run()
@@ -40,9 +40,6 @@ from dragnn.python import sentence_io
 from dragnn.python import spec_builder
 from dragnn.python import trainer_lib
-import dragnn.python.load_dragnn_cc_impl
-import syntaxnet.load_parser_ops
 flags = tf.app.flags
 FLAGS = flags.FLAGS
...
@@ -42,9 +42,6 @@ from syntaxnet import sentence_pb2
 from syntaxnet.ops import gen_parser_ops
 from syntaxnet.util import check
-import dragnn.python.load_dragnn_cc_impl
-import syntaxnet.load_parser_ops
 flags = tf.app.flags
 FLAGS = flags.FLAGS
...
@@ -42,9 +42,6 @@ from dragnn.python import lexicon
 from dragnn.python import spec_builder
 from dragnn.python import trainer_lib
-import dragnn.python.load_dragnn_cc_impl
-import syntaxnet.load_parser_ops
 flags = tf.app.flags
 FLAGS = flags.FLAGS
...
@@ -45,9 +45,6 @@ from dragnn.python import trainer_lib
 from syntaxnet.util import check
-import dragnn.python.load_dragnn_cc_impl
-import syntaxnet.load_parser_ops
 flags = tf.app.flags
 FLAGS = flags.FLAGS
...
/**
* Template for node info.
*/
goog.module('nlp.saft.opensource.dragnn.viz.node_info');
import preact from 'preact';
import _ from 'lodash';
const normalCell = {
'border': 0,
'border-collapse': 'separate',
'padding': '2px',
};
/**
* Style definitions which are directly injected (see README.md comments).
*/
const style = {
featuresTable: {
'background-color': 'rgba(255, 255, 255, 0.9)',
'border': '1px solid #dddddd',
'border-spacing': '2px',
'border-collapse': 'separate',
    'font-family': 'roboto, helvetica, arial, sans-serif',
// Sometimes state strings (`stateHtml`) get long, and because this is an
// absolutely-positioned box, we need to make them wrap around.
'max-width': '600px',
'position': 'absolute',
},
heading: {
'background-color': '#ebf5fb',
'font-weight': 'bold',
'text-align': 'center',
...normalCell
},
normalCell: normalCell,
featureGroup: (componentColor) => ({
'background-color': componentColor,
'font-weight': 'bold',
...normalCell
}),
normalRow: {
'border': 0,
'border-collapse': 'separate',
},
};
/**
* Creates table rows that negate IPython/Jupyter notebook styling.
*
* @param {?XML|?Array<XML>} children Child nodes. (Recall Preact handles
* null/undefined gracefully).
* @param {!Object} props Any additional properties.
* @return {!XML} React-y element, representing a table row.
*/
const Row = ({children, ...props}) => (
<tr style={style.normalRow} {...props}>{children}</tr>);
/**
* Creates table cells that negate IPython/Jupyter notebook styling.
*
* @param {?XML|?Array<XML>} children Child nodes. (Recall Preact handles
* null/undefined gracefully).
* @param {!Object} props Any additional properties.
* @return {!XML} React-y element, representing a table cell.
*/
const Cell = ({children, ...props}) => (
<td style={style.normalCell} {...props}>{children}</td>);
/**
* Construct a table "multi-row" with a shared "header" cell.
*
* In ASCII-art,
*
* ------------------------------
* | row1
* header | row2
* | row3
* ------------------------------
*
* @param {string} headerText Text for the header cell
* @param {string} headerColor Color of the header cell
* @param {!Array<XML>} rowsCells Row cells (<td> React-y elements).
* @return {!Array<XML>} Array of React-y elements.
*/
const featureGroup = (headerText, headerColor, rowsCells) => {
const headerCell = (
<td rowspan={rowsCells.length} style={style.featureGroup(headerColor)}>
{headerText}
</td>
);
return _.map(rowsCells, (cells, i) => {
return <Row>{i == 0 ? headerCell : null}{cells}</Row>;
});
};
/**
* Mini helper to intersperse line breaks with a list of elements.
*
* This just replicates previous behavior and looks OK; we could also try spans
* with `display: 'block'` or such.
*
* @param {!Array<XML>} elements React-y elements.
* @return {!Array<XML>} React-y elements with line breaks.
*/
const intersperseLineBreaks = (elements) => _.tail(_.flatten(_.map(
elements, (v) => [<br />, v]
)));
export default class NodeInfo extends preact.Component {
/**
* Obligatory Preact render() function.
*
* It might be worthwhile converting some of the intermediate variables into
* stateless functional components, like Cell and Row.
*
* @param {?Object} selected Cytoscape node selected (null if no selection).
* @param {?Object} mousePosition Mouse position, if a node is selected.
* @return {!XML} Preact components to render.
*/
render({selected, mousePosition}) {
const visible = selected != null;
const stateHtml = visible && selected.data('stateInfo');
// Generates elements for fixed features.
const fixedFeatures = visible ? selected.data('fixedFeatures') : [];
const fixedFeatureElements = _.map(fixedFeatures, (feature) => {
if (feature.value_trace.length == 0) {
// Preact will just prune this out.
return null;
} else {
const rowsCells = _.map(feature.value_trace, (value) => {
// Recall `value_name` is a list of strings (representing feature
// values), but this is OK because strings are valid react elements.
const valueCells = intersperseLineBreaks(value.value_name);
return [<Cell>{value.feature_name}</Cell>, <Cell>{valueCells}</Cell>];
});
return featureGroup(feature.name, '#cccccc', _.map(rowsCells));
}
});
/**
* Generates linked feature info from an edge.
*
* @param {!Object} edge Cytoscape JS Element representing a linked feature.
* @return {[XML,XML]} Linked feature information, as table elements.
*/
const linkedFeatureInfoFromEdge = (edge) => {
return [
<Cell>{edge.data('featureName')}</Cell>,
<Cell>
value {edge.data('featureValue')} from
step {edge.source().data('stepIdx')}
</Cell>
];
};
const linkedFeatureElements = _.flatten(
_.map(this.edgeStatesByComponent(), (edges, componentName) => {
// Because edges are generated by `incomers`, it is guaranteed to be
// non-empty.
const color = _.head(edges).source().parent().data('componentColor');
const rowsCells = _.map(edges, linkedFeatureInfoFromEdge);
return featureGroup(componentName, color, rowsCells);
}));
let positionOrHiddenStyle;
if (visible) {
positionOrHiddenStyle = {
left: mousePosition.x + 20,
top: mousePosition.y + 10,
};
} else {
positionOrHiddenStyle = {display: 'none'};
}
return (
<table style={_.defaults(positionOrHiddenStyle, style.featuresTable)}>
<Row>
<td colspan="3" style={style.heading}>State</td>
</Row>
<Row>
<Cell colspan="3">{stateHtml}</Cell>
</Row>
<Row>
<td colspan="3" style={style.heading}>Features</td>
</Row>
{fixedFeatureElements}
{linkedFeatureElements}
</table>
);
}
/**
* Gets a list of incoming edges, grouped by their component name.
*
* @return {!Object<string, !Array<!Object>>} Map from component name to list
* of edges.
*/
edgeStatesByComponent() {
if (this.props.selected == null) {
return [];
}
const incoming = this.props.selected.incomers(); // edges and nodes
return _.groupBy(incoming.edges(), (edge) => edge.source().parent().id());
}
}
# Format: google3/devtools/metadata/metadata.proto (go/google3metadata)
name: "syntaxnet"
# Use "base" template
g3doc {
headerfooter {
path_regexp: ".*\\.md$"
name: "base"
}
navbar_file : "/company/teams/saft/navbar.md"
logo : "/company/teams/saft/images/logo.png"
favicon : "/company/teams/saft/images/saft-favicon.png"
}
teams_product_id: 7805219680
@@ -3,7 +3,7 @@
 ### Module `dragnn_ops`
 Defined in
-[`tensorflow/dragnn/python/dragnn_ops.py`](https://github.com/tensorflow/models/blob/master/research/syntaxnet/dragnn/python/dragnn_ops.py).
+[`tensorflow/dragnn/python/dragnn_ops.py`](https://github.com/tensorflow/models/blob/master/syntaxnet/dragnn/python/dragnn_ops.py).
 Groups the DRAGNN TensorFlow ops in one module.
...
# Module: dragnn_ops.google3
### Module `dragnn_ops.google3`
This is the root of the google3 tree.
Code in here is built by the Google3 build system.
## Members
@@ -87,7 +87,8 @@ Note that `stack` here means "words we have already tagged." Thus, this feature
 spec uses three types of features: words, suffixes, and prefixes. The features
 are grouped into blocks that share an embedding matrix, concatenated together,
 and fed into a chain of hidden layers. This structure is based upon the model
-proposed by [Chen and Manning (2014)](http://cs.stanford.edu/people/danqi/papers/emnlp2014.pdf).
+proposed by [Chen and Manning (2014)]
+(http://cs.stanford.edu/people/danqi/papers/emnlp2014.pdf).
 We show this layout in the schematic below: the state of the system (a stack and
 a buffer, visualized below for both the POS and the dependency parsing task) is
...
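The documentation excerpt above describes feature groups that each share an embedding matrix, are concatenated, and are fed through a chain of hidden layers. As a rough illustration only (this is not the DRAGNN implementation; all names, shapes, and dimensions below are hypothetical), that layout can be sketched in TF 1.x style:

import tensorflow as tf

def embed_concat_feed_forward(feature_ids, vocab_sizes, embed_dims, hidden_dims):
  """Embeds each feature block, concatenates blocks, applies hidden layers.

  feature_ids: dict mapping block name (e.g. 'words', 'suffixes', 'prefixes')
    to an int32 tensor of shape [batch, num_features_in_block] (static width).
  vocab_sizes, embed_dims: dicts keyed by the same block names.
  hidden_dims: list of hidden layer sizes.
  """
  blocks = []
  for name, ids in sorted(feature_ids.items()):
    with tf.variable_scope('embedding_' + name):
      # All features in a block share one embedding matrix.
      matrix = tf.get_variable('matrix', [vocab_sizes[name], embed_dims[name]])
      embedded = tf.nn.embedding_lookup(matrix, ids)
      # Flatten [batch, num_features, dim] -> [batch, num_features * dim].
      num_features = int(ids.shape[1])
      blocks.append(tf.reshape(embedded, [-1, num_features * embed_dims[name]]))
  # Concatenate all blocks and feed them through the hidden layers.
  activations = tf.concat(blocks, axis=1)
  for i, dim in enumerate(hidden_dims):
    activations = tf.layers.dense(
        activations, dim, activation=tf.nn.relu, name='hidden_%d' % i)
  return activations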
@@ -331,6 +331,8 @@ cc_library(
         "binary_segment_state.cc",
         "binary_segment_transitions.cc",
         "char_shift_transitions.cc",
+        "head_label_transitions.cc",
+        "head_label_transitions.h",
         "head_transitions.cc",
         "head_transitions.h",
         "label_transitions.cc",
@@ -353,6 +355,7 @@ cc_library(
     deps = [
         ":base",
         ":feature_extractor",
+        ":generic_features",
         ":morphology_label_set",
         ":registry",
         ":segmenter_utils",
@@ -399,6 +402,16 @@ cc_library(
     ],
 )
 
+cc_library(
+    name = "generic_features",
+    srcs = ["generic_features.cc"],
+    hdrs = ["generic_features.h"],
+    deps = [
+        ":feature_extractor",
+        ":registry",
+    ],
+)
+
 cc_library(
     name = "sentence_batch",
     srcs = ["sentence_batch.cc"],
@@ -510,6 +523,7 @@ filegroup(
     srcs = [
         "testdata/context.pbtxt",
         "testdata/document",
+        "testdata/hello.txt",
         "testdata/mini-training-set",
     ],
 )
@@ -572,6 +586,17 @@ cc_test(
     ],
 )
 
+cc_test(
+    name = "generic_features_test",
+    srcs = ["generic_features_test.cc"],
+    deps = [
+        ":generic_features",
+        ":registry",
+        ":task_context",
+        ":test_main",
+    ],
+)
+
 cc_test(
     name = "sentence_features_test",
     size = "medium",
@@ -712,6 +737,20 @@ cc_test(
     ],
 )
 
+cc_test(
+    name = "head_label_transitions_test",
+    size = "small",
+    srcs = ["head_label_transitions_test.cc"],
+    deps = [
+        ":base",
+        ":parser_transitions",
+        ":sentence_proto",
+        ":task_context",
+        ":term_frequency_map",
+        ":test_main",
+    ],
+)
+
 cc_test(
     name = "parser_features_test",
     size = "small",
@@ -750,8 +789,8 @@ py_library(
     name = "syntaxnet_ops",
     srcs = ["syntaxnet_ops.py"],
     deps = [
-        ":parser_ops",
         ":load_parser_ops_py",
+        ":parser_ops",
     ],
 )
...
@@ -269,7 +269,9 @@ class ArcStandardTransitionSystem : public ParserTransitionSystem {
   void PerformRightArc(ParserState *state, int label) const {
     DCHECK(IsAllowedRightArc(*state));
     int s0 = state->Pop();
-    state->AddArc(s0, state->Top(), label);
+    int s1 = state->Pop();
+    state->AddArc(s0, s1, label);
+    state->Push(s1);
   }
 
   // We are in a deterministic state when we either reached the end of the input
...
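For readers less familiar with the arc-standard system: the hunk above rewrites the right-arc transition so that the head is popped explicitly and pushed back, rather than read with Top() after popping the dependent. A right arc attaches the top of the stack (s0) as a dependent of the element beneath it (s1) and leaves s1 on the stack. A minimal Python sketch of that behavior (a hypothetical list-based helper, not the SyntaxNet API):

def perform_right_arc(stack, arcs, label):
  # Arc-standard RIGHT-ARC on a list-based stack (top of stack = last element).
  s0 = stack.pop()              # dependent: current top of the stack
  s1 = stack.pop()              # head: the element beneath it
  arcs.append((s0, s1, label))  # record the arc as (dependent, head, label)
  stack.append(s1)              # the head stays on the stack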
@@ -21,6 +21,7 @@ limitations under the License.
 #include <unordered_map>
 #include <unordered_set>
 #include <vector>
+
 #include "tensorflow/core/lib/core/status.h"
 #include "tensorflow/core/lib/strings/strcat.h"
 #include "tensorflow/core/lib/strings/stringprintf.h"
@@ -30,11 +31,14 @@ limitations under the License.
+using tensorflow::int8;
+using tensorflow::int16;
 using tensorflow::int32;
 using tensorflow::int64;
+using tensorflow::uint8;
+using tensorflow::uint16;
 using tensorflow::uint64;
 using tensorflow::uint32;
-using tensorflow::uint32;
 using tensorflow::protobuf::TextFormat;
 using tensorflow::mutex_lock;
 using tensorflow::mutex;
@@ -48,6 +52,7 @@ typedef signed int char32;
 using tensorflow::StringPiece;
 using std::string;
+
 // namespace syntaxnet
 #endif  // SYNTAXNET_BASE_H_
@@ -28,8 +28,7 @@ namespace syntaxnet {
 // -MERGE: adds the token at state.input to its prevous word, and also advances
 // state.input.
 //
-// Also see nlp/saft/components/segmentation/transition/binary-segment-state.h
-// for examples on handling spaces.
+// Also see binary_segment_state.h for examples on handling spaces.
 class BinarySegmentTransitionSystem : public ParserTransitionSystem {
  public:
   BinarySegmentTransitionSystem() {}
...
@@ -203,8 +203,8 @@ void CharProperty::AddAsciiPredicate(AsciiPredicate *pred) {
 void CharProperty::AddCharProperty(const char *propname) {
   const CharProperty *prop = CharProperty::Lookup(propname);
-  CHECK(prop != NULL) << ": unknown char property \"" << propname
-                      << "\" in " << name_;
+  CHECK(prop != nullptr) << ": unknown char property \"" << propname << "\" in "
+                         << name_;
   int c = -1;
   while ((c = prop->NextElementAfter(c)) >= 0) {
     AddChar(c);
@@ -268,10 +268,10 @@ const CharProperty *CharProperty::Lookup(const char *subclass) {
   // the CharProperty it provides.
   std::unique_ptr<CharPropertyWrapper> wrapper(
       CharPropertyWrapper::Create(subclass));
-  if (wrapper.get() == NULL) {
+  if (wrapper == nullptr) {
     LOG(ERROR) << "CharPropertyWrapper not found for subclass: "
                << "\"" << subclass << "\"";
-    return NULL;
+    return nullptr;
   }
   return wrapper->GetCharProperty();
 }
...
@@ -357,6 +357,8 @@ DECLARE_CHAR_PROPERTY(directional_formatting_code);
 // just those listed in our code. See the definitions in char_properties.cc.
 DECLARE_CHAR_PROPERTY(punctuation_or_symbol);
 
+DECLARE_SYNTAXNET_CLASS_REGISTRY("char property wrapper", CharPropertyWrapper);
+
 }  // namespace syntaxnet
 #endif  // SYNTAXNET_CHAR_PROPERTIES_H_
@@ -118,6 +118,7 @@ class CharShiftTransitionTest : public ::testing::Test {
  protected:
   string MultiFeatureString(const FeatureVector &result) {
     std::vector<string> values;
+    values.reserve(result.size());
     for (int i = 0; i < result.size(); ++i) {
       values.push_back(result.type(i)->GetFeatureValueName(result.value(i)));
     }
...