Unverified commit 80178fc6 authored by Mark Omernick, committed by GitHub

Merge pull request #4153 from terryykoo/master

Export @195097388.
parents a84e1ef9 edea2b67
@@ -26,6 +26,7 @@ import os
 import re
 import time
+from absl import flags
 import tensorflow as tf
 from google.protobuf import text_format
@@ -39,7 +40,6 @@ from dragnn.python import sentence_io
 from dragnn.python import spec_builder
 from syntaxnet import sentence_pb2
-flags = tf.app.flags
 FLAGS = flags.FLAGS
 flags.DEFINE_string('master_spec', '',
...
@@ -19,6 +19,7 @@ r"""Runs both a segmentation and a parsing model on a CoNLL dataset.
 import re
 import time
+from absl import flags
 import tensorflow as tf
 from google.protobuf import text_format
@@ -34,7 +35,6 @@ from syntaxnet import sentence_pb2
 from syntaxnet.ops import gen_parser_ops
 from syntaxnet.util import check
-flags = tf.app.flags
 FLAGS = flags.FLAGS
...
@@ -42,6 +42,8 @@ import ast
 import collections
 import os
 import os.path
+from absl import app
+from absl import flags
 import tensorflow as tf
 from google.protobuf import text_format
@@ -55,7 +57,6 @@ from dragnn.python import trainer_lib
 from syntaxnet.ops import gen_parser_ops
 from syntaxnet.util import check
-flags = tf.app.flags
 FLAGS = flags.FLAGS
 flags.DEFINE_string('tf_master', '',
@@ -191,4 +192,4 @@ def main(unused_argv):
 if __name__ == '__main__':
-  tf.app.run()
+  app.run(main)
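Every Python hunk in this change applies the same two-step migration: replace the deprecated `flags = tf.app.flags` alias with a direct `from absl import flags` import, and pass `main` explicitly via `app.run(main)` instead of relying on `tf.app.run()` to discover it. A minimal sketch of the resulting pattern (the flag name here is illustrative):

```python
from absl import app
from absl import flags

FLAGS = flags.FLAGS
flags.DEFINE_string('master_spec', '', 'Path to a DRAGNN master spec.')


def main(unused_argv):
  # absl has already parsed sys.argv into FLAGS by the time main runs.
  print('master_spec = %s' % FLAGS.master_spec)


if __name__ == '__main__':
  app.run(main)
```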
# Copyright 2017 Google Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
r"""Runs a both a segmentation and parsing model on a CoNLL dataset.
"""
import re
import time
import tensorflow as tf
from google.protobuf import text_format
from tensorflow.python.client import timeline
from tensorflow.python.platform import gfile
from dragnn.protos import spec_pb2
from dragnn.python import graph_builder
from dragnn.python import sentence_io
from dragnn.python import spec_builder
from syntaxnet import sentence_pb2
from syntaxnet.ops import gen_parser_ops
from syntaxnet.util import check
flags = tf.app.flags
FLAGS = flags.FLAGS

flags.DEFINE_string('parser_master_spec', '',
                    'Path to text file containing a DRAGNN master spec to run.')
flags.DEFINE_string('parser_checkpoint_file', '',
                    'Path to trained model checkpoint.')
flags.DEFINE_string('parser_resource_dir', '',
                    'Optional base directory for resources in the master spec.')
flags.DEFINE_string('segmenter_master_spec', '',
                    'Path to text file containing a DRAGNN master spec to run.')
flags.DEFINE_string('segmenter_checkpoint_file', '',
                    'Path to trained model checkpoint.')
flags.DEFINE_string('segmenter_resource_dir', '',
                    'Optional base directory for resources in the master spec.')
flags.DEFINE_bool('complete_master_spec', True, 'Whether the master_specs '
                  'need the lexicon and other resources added to them.')
flags.DEFINE_string('input_file', '',
                    'File of CoNLL-formatted sentences to read from.')
flags.DEFINE_string('output_file', '',
                    'File path to write annotated sentences to.')
flags.DEFINE_integer('max_batch_size', 2048, 'Maximum batch size to support.')
flags.DEFINE_string('inference_beam_size', '', 'Comma separated list of '
                    'component_name=beam_size pairs.')
flags.DEFINE_string('locally_normalize', '', 'Comma separated list of '
                    'component names to do local normalization on.')
flags.DEFINE_integer('threads', 10, 'Number of threads used for intra- and '
                     'inter-op parallelism.')
flags.DEFINE_string('timeline_output_file', '', 'Path to save timeline to. '
                    'If specified, the final iteration of the evaluation loop '
                    'will capture and save a TensorFlow timeline.')
flags.DEFINE_bool('use_gold_segmentation', False,
                  'Whether or not to use gold segmentation.')
def main(unused_argv):
  # Parse the flags containing lists, using regular expressions.
  # This matches and extracts key=value pairs.
  component_beam_sizes = re.findall(r'([^=,]+)=(\d+)',
                                    FLAGS.inference_beam_size)
  # This matches strings separated by a comma. Does not return any empty
  # strings.
  components_to_locally_normalize = re.findall(r'[^,]+',
                                               FLAGS.locally_normalize)

  ## SEGMENTATION ##
  if not FLAGS.use_gold_segmentation:
    # Reads master spec.
    master_spec = spec_pb2.MasterSpec()
    with gfile.FastGFile(FLAGS.segmenter_master_spec) as fin:
      text_format.Parse(fin.read(), master_spec)

    if FLAGS.complete_master_spec:
      spec_builder.complete_master_spec(
          master_spec, None, FLAGS.segmenter_resource_dir)

    # Graph building.
    tf.logging.info('Building the graph')
    g = tf.Graph()
    with g.as_default(), tf.device('/device:CPU:0'):
      hyperparam_config = spec_pb2.GridPoint()
      hyperparam_config.use_moving_average = True
      builder = graph_builder.MasterBuilder(master_spec, hyperparam_config)
      annotator = builder.add_annotation()
      builder.add_saver()

    tf.logging.info('Reading documents...')
    input_corpus = sentence_io.ConllSentenceReader(FLAGS.input_file).corpus()
    with tf.Session(graph=tf.Graph()) as tmp_session:
      char_input = gen_parser_ops.char_token_generator(input_corpus)
      char_corpus = tmp_session.run(char_input)
    check.Eq(len(input_corpus), len(char_corpus))

    session_config = tf.ConfigProto(
        log_device_placement=False,
        intra_op_parallelism_threads=FLAGS.threads,
        inter_op_parallelism_threads=FLAGS.threads)

    with tf.Session(graph=g, config=session_config) as sess:
      tf.logging.info('Initializing variables...')
      sess.run(tf.global_variables_initializer())
      tf.logging.info('Loading from checkpoint...')
      sess.run('save/restore_all',
               {'save/Const:0': FLAGS.segmenter_checkpoint_file})

      tf.logging.info('Processing sentences...')

      processed = []
      start_time = time.time()
      run_metadata = tf.RunMetadata()
      for start in range(0, len(char_corpus), FLAGS.max_batch_size):
        end = min(start + FLAGS.max_batch_size, len(char_corpus))
        feed_dict = {annotator['input_batch']: char_corpus[start:end]}
        if FLAGS.timeline_output_file and end == len(char_corpus):
          serialized_annotations = sess.run(
              annotator['annotations'], feed_dict=feed_dict,
              options=tf.RunOptions(trace_level=tf.RunOptions.FULL_TRACE),
              run_metadata=run_metadata)
          trace = timeline.Timeline(step_stats=run_metadata.step_stats)
          with open(FLAGS.timeline_output_file, 'w') as trace_file:
            trace_file.write(trace.generate_chrome_trace_format())
        else:
          serialized_annotations = sess.run(
              annotator['annotations'], feed_dict=feed_dict)
        processed.extend(serialized_annotations)

      tf.logging.info('Processed %d documents in %.2f seconds.',
                      len(char_corpus), time.time() - start_time)
    input_corpus = processed
  else:
    input_corpus = sentence_io.ConllSentenceReader(FLAGS.input_file).corpus()

  ## PARSING ##
  # Reads master spec.
  master_spec = spec_pb2.MasterSpec()
  with gfile.FastGFile(FLAGS.parser_master_spec) as fin:
    text_format.Parse(fin.read(), master_spec)

  if FLAGS.complete_master_spec:
    spec_builder.complete_master_spec(
        master_spec, None, FLAGS.parser_resource_dir)

  # Graph building.
  tf.logging.info('Building the graph')
  g = tf.Graph()
  with g.as_default(), tf.device('/device:CPU:0'):
    hyperparam_config = spec_pb2.GridPoint()
    hyperparam_config.use_moving_average = True
    builder = graph_builder.MasterBuilder(master_spec, hyperparam_config)
    annotator = builder.add_annotation()
    builder.add_saver()

  tf.logging.info('Reading documents...')

  session_config = tf.ConfigProto(
      log_device_placement=False,
      intra_op_parallelism_threads=FLAGS.threads,
      inter_op_parallelism_threads=FLAGS.threads)

  with tf.Session(graph=g, config=session_config) as sess:
    tf.logging.info('Initializing variables...')
    sess.run(tf.global_variables_initializer())
    tf.logging.info('Loading from checkpoint...')
    sess.run('save/restore_all', {'save/Const:0': FLAGS.parser_checkpoint_file})

    tf.logging.info('Processing sentences...')

    processed = []
    start_time = time.time()
    run_metadata = tf.RunMetadata()
    for start in range(0, len(input_corpus), FLAGS.max_batch_size):
      end = min(start + FLAGS.max_batch_size, len(input_corpus))
      feed_dict = {annotator['input_batch']: input_corpus[start:end]}
      for comp, beam_size in component_beam_sizes:
        feed_dict['%s/InferenceBeamSize:0' % comp] = beam_size
      for comp in components_to_locally_normalize:
        feed_dict['%s/LocallyNormalize:0' % comp] = True
      if FLAGS.timeline_output_file and end == len(input_corpus):
        serialized_annotations = sess.run(
            annotator['annotations'], feed_dict=feed_dict,
            options=tf.RunOptions(trace_level=tf.RunOptions.FULL_TRACE),
            run_metadata=run_metadata)
        trace = timeline.Timeline(step_stats=run_metadata.step_stats)
        with open(FLAGS.timeline_output_file, 'w') as trace_file:
          trace_file.write(trace.generate_chrome_trace_format())
      else:
        serialized_annotations = sess.run(
            annotator['annotations'], feed_dict=feed_dict)
      processed.extend(serialized_annotations)

    tf.logging.info('Processed %d documents in %.2f seconds.',
                    len(input_corpus), time.time() - start_time)

  if FLAGS.output_file:
    with gfile.GFile(FLAGS.output_file, 'w') as f:
      for serialized_sentence in processed:
        sentence = sentence_pb2.Sentence()
        sentence.ParseFromString(serialized_sentence)
        f.write('#' + sentence.text.encode('utf-8') + '\n')
        for i, token in enumerate(sentence.token):
          head = token.head + 1
          f.write('%s\t%s\t_\t_\t_\t_\t%d\t%s\t_\t_\n' % (
              i + 1,
              token.word.encode('utf-8'), head,
              token.label.encode('utf-8')))
        f.write('\n\n')


if __name__ == '__main__':
  tf.app.run()
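The two `re.findall` calls at the top of `main` are the whole parsing story for the list-valued flags. A standalone sketch of their behavior (the flag values here are made up):

```python
import re

beam_flag = 'segmenter=1,parser=8'   # hypothetical --inference_beam_size
norm_flag = 'segmenter,parser'       # hypothetical --locally_normalize

# key=value pairs: [('segmenter', '1'), ('parser', '8')]
print(re.findall(r'([^=,]+)=(\d+)', beam_flag))

# comma-separated names, empty strings dropped: ['segmenter', 'parser']
print(re.findall(r'[^,]+', norm_flag))
```

Note the beam sizes come back as strings; the script feeds them into the `%s/InferenceBeamSize:0` tensors as-is.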
@@ -17,6 +17,7 @@ r"""Runs both a segmentation and a parsing model on a CoNLL dataset.
 import re
 import time
+from absl import flags
 import tensorflow as tf
 from tensorflow.python.client import timeline
@@ -35,7 +36,6 @@ from syntaxnet import syntaxnet_ops
 from syntaxnet.ops import gen_parser_ops
 from syntaxnet.util import check
-flags = tf.app.flags
 FLAGS = flags.FLAGS
 flags.DEFINE_string(
...
@@ -21,6 +21,8 @@ import os
 import os.path
 import random
 import time
+from absl import app
+from absl import flags
 import tensorflow as tf
 from tensorflow.python.platform import gfile
@@ -40,7 +42,6 @@ from dragnn.python import sentence_io
 from dragnn.python import spec_builder
 from dragnn.python import trainer_lib
-flags = tf.app.flags
 FLAGS = flags.FLAGS
 flags.DEFINE_string('tf_master', '',
@@ -189,4 +190,4 @@ def main(unused_argv):
 if __name__ == '__main__':
-  tf.app.run()
+  app.run(main)
@@ -27,6 +27,7 @@ import os
 import re
 import time
+from absl import flags
 import tensorflow as tf
 from google.protobuf import text_format
@@ -42,7 +43,6 @@ from syntaxnet import sentence_pb2
 from syntaxnet.ops import gen_parser_ops
 from syntaxnet.util import check
-flags = tf.app.flags
 FLAGS = flags.FLAGS
 flags.DEFINE_string('master_spec', '',
...
@@ -22,6 +22,7 @@ import os
 import os.path
 import random
 import time
+from absl import flags
 import tensorflow as tf
 from tensorflow.python.platform import gfile
@@ -42,7 +43,6 @@ from dragnn.python import lexicon
 from dragnn.python import spec_builder
 from dragnn.python import trainer_lib
-flags = tf.app.flags
 FLAGS = flags.FLAGS
 flags.DEFINE_string('tf_master', '',
...
@@ -22,6 +22,7 @@ import os
 import os.path
 import random
 import time
+from absl import flags
 import tensorflow as tf
 from tensorflow.python.framework import errors
@@ -45,7 +46,6 @@ from dragnn.python import trainer_lib
 from syntaxnet.util import check
-flags = tf.app.flags
 FLAGS = flags.FLAGS
 flags.DEFINE_string('tf_master', '',
...
/**
 * Template for node info.
 */
goog.module('nlp.saft.opensource.dragnn.viz.node_info');

import preact from 'preact';
import _ from 'lodash';

const normalCell = {
  'border': 0,
  'border-collapse': 'separate',
  'padding': '2px',
};

/**
 * Style definitions which are directly injected (see README.md comments).
 */
const style = {
  featuresTable: {
    'background-color': 'rgba(255, 255, 255, 0.9)',
    'border': '1px solid #dddddd',
    'border-spacing': '2px',
    'border-collapse': 'separate',
    'font-family': 'roboto, helvetica, arial, sans-serif',
    // Sometimes state strings (`stateHtml`) get long, and because this is an
    // absolutely-positioned box, we need to make them wrap around.
    'max-width': '600px',
    'position': 'absolute',
  },
  heading: {
    'background-color': '#ebf5fb',
    'font-weight': 'bold',
    'text-align': 'center',
    ...normalCell
  },
  normalCell: normalCell,
  featureGroup: (componentColor) => ({
    'background-color': componentColor,
    'font-weight': 'bold',
    ...normalCell
  }),
  normalRow: {
    'border': 0,
    'border-collapse': 'separate',
  },
};
/**
 * Creates table rows that negate IPython/Jupyter notebook styling.
 *
 * @param {?XML|?Array<XML>} children Child nodes. (Recall Preact handles
 *     null/undefined gracefully).
 * @param {!Object} props Any additional properties.
 * @return {!XML} React-y element, representing a table row.
 */
const Row = ({children, ...props}) => (
    <tr style={style.normalRow} {...props}>{children}</tr>);

/**
 * Creates table cells that negate IPython/Jupyter notebook styling.
 *
 * @param {?XML|?Array<XML>} children Child nodes. (Recall Preact handles
 *     null/undefined gracefully).
 * @param {!Object} props Any additional properties.
 * @return {!XML} React-y element, representing a table cell.
 */
const Cell = ({children, ...props}) => (
    <td style={style.normalCell} {...props}>{children}</td>);

/**
 * Construct a table "multi-row" with a shared "header" cell.
 *
 * In ASCII-art,
 *
 *     ------------------------------
 *            | row1
 *     header | row2
 *            | row3
 *     ------------------------------
 *
 * @param {string} headerText Text for the header cell
 * @param {string} headerColor Color of the header cell
 * @param {!Array<XML>} rowsCells Row cells (<td> React-y elements).
 * @return {!Array<XML>} Array of React-y elements.
 */
const featureGroup = (headerText, headerColor, rowsCells) => {
  const headerCell = (
      <td rowspan={rowsCells.length} style={style.featureGroup(headerColor)}>
        {headerText}
      </td>
  );
  return _.map(rowsCells, (cells, i) => {
    return <Row>{i == 0 ? headerCell : null}{cells}</Row>;
  });
};

/**
 * Mini helper to intersperse line breaks with a list of elements.
 *
 * This just replicates previous behavior and looks OK; we could also try spans
 * with `display: 'block'` or such.
 *
 * @param {!Array<XML>} elements React-y elements.
 * @return {!Array<XML>} React-y elements with line breaks.
 */
const intersperseLineBreaks = (elements) => _.tail(_.flatten(_.map(
    elements, (v) => [<br />, v]
)));
export default class NodeInfo extends preact.Component {
  /**
   * Obligatory Preact render() function.
   *
   * It might be worthwhile converting some of the intermediate variables into
   * stateless functional components, like Cell and Row.
   *
   * @param {?Object} selected Cytoscape node selected (null if no selection).
   * @param {?Object} mousePosition Mouse position, if a node is selected.
   * @return {!XML} Preact components to render.
   */
  render({selected, mousePosition}) {
    const visible = selected != null;
    const stateHtml = visible && selected.data('stateInfo');

    // Generates elements for fixed features.
    const fixedFeatures = visible ? selected.data('fixedFeatures') : [];
    const fixedFeatureElements = _.map(fixedFeatures, (feature) => {
      if (feature.value_trace.length == 0) {
        // Preact will just prune this out.
        return null;
      } else {
        const rowsCells = _.map(feature.value_trace, (value) => {
          // Recall `value_name` is a list of strings (representing feature
          // values), but this is OK because strings are valid react elements.
          const valueCells = intersperseLineBreaks(value.value_name);
          return [<Cell>{value.feature_name}</Cell>, <Cell>{valueCells}</Cell>];
        });
        return featureGroup(feature.name, '#cccccc', _.map(rowsCells));
      }
    });

    /**
     * Generates linked feature info from an edge.
     *
     * @param {!Object} edge Cytoscape JS Element representing a linked feature.
     * @return {[XML,XML]} Linked feature information, as table elements.
     */
    const linkedFeatureInfoFromEdge = (edge) => {
      return [
        <Cell>{edge.data('featureName')}</Cell>,
        <Cell>
          value {edge.data('featureValue')} from
          step {edge.source().data('stepIdx')}
        </Cell>
      ];
    };
    const linkedFeatureElements = _.flatten(
        _.map(this.edgeStatesByComponent(), (edges, componentName) => {
          // Because edges are generated by `incomers`, it is guaranteed to be
          // non-empty.
          const color = _.head(edges).source().parent().data('componentColor');
          const rowsCells = _.map(edges, linkedFeatureInfoFromEdge);
          return featureGroup(componentName, color, rowsCells);
        }));

    let positionOrHiddenStyle;
    if (visible) {
      positionOrHiddenStyle = {
        left: mousePosition.x + 20,
        top: mousePosition.y + 10,
      };
    } else {
      positionOrHiddenStyle = {display: 'none'};
    }
    return (
        <table style={_.defaults(positionOrHiddenStyle, style.featuresTable)}>
          <Row>
            <td colspan="3" style={style.heading}>State</td>
          </Row>
          <Row>
            <Cell colspan="3">{stateHtml}</Cell>
          </Row>
          <Row>
            <td colspan="3" style={style.heading}>Features</td>
          </Row>
          {fixedFeatureElements}
          {linkedFeatureElements}
        </table>
    );
  }

  /**
   * Gets a list of incoming edges, grouped by their component name.
   *
   * @return {!Object<string, !Array<!Object>>} Map from component name to list
   *     of edges.
   */
  edgeStatesByComponent() {
    if (this.props.selected == null) {
      return {};  // empty map, matching the declared return type
    }
    const incoming = this.props.selected.incomers();  // edges and nodes
    return _.groupBy(incoming.edges(), (edge) => edge.source().parent().id());
  }
}
@@ -17,7 +17,7 @@ py_library(
     deps = [
         "//dragnn/core:dragnn_bulk_ops",
         "//dragnn/core:dragnn_ops",
-        "//dragnn/protos:spec_py_pb2",
+        "//dragnn/protos:spec_pb2_py",
         "//dragnn/python:graph_builder",
         "//dragnn/python:lexicon",
         "//dragnn/python:load_dragnn_cc_impl_py",
@@ -25,7 +25,7 @@ py_library(
         "//dragnn/python:visualization",
         "//syntaxnet:load_parser_ops_py",
         "//syntaxnet:parser_ops",
-        "//syntaxnet:sentence_py_pb2",
+        "//syntaxnet:sentence_pb2_py",
         "@org_tensorflow//tensorflow:tensorflow_py",
         "@org_tensorflow//tensorflow/core:protos_all_py",
     ],
@@ -34,6 +34,7 @@ py_library(
 filegroup(
     name = "data",
     data = glob(["tutorial_data/*"]),
+    visibility = ["//visibility:public"],
 )

 sh_test(
...
@@ -11,63 +11,66 @@ package(
 licenses(["notice"])  # Apache 2.0

 load(
-    "syntaxnet",
-    "tf_proto_library",
-    "tf_proto_library_py",
+    "@org_tensorflow//tensorflow:tensorflow.bzl",
     "tf_gen_op_libs",
     "tf_gen_op_wrapper_py",
 )
+load(
+    ":syntaxnet.bzl",
+    "tf_proto_library_cc",
+    "tf_proto_library_py",
+)

 # proto libraries

-tf_proto_library(
+tf_proto_library_cc(
     name = "feature_extractor_proto",
     srcs = ["feature_extractor.proto"],
 )

-tf_proto_library(
+tf_proto_library_cc(
     name = "sentence_proto",
     srcs = ["sentence.proto"],
 )

 tf_proto_library_py(
-    name = "sentence_py_pb2",
+    name = "sentence_pb2",
     srcs = ["sentence.proto"],
 )

-tf_proto_library(
+tf_proto_library_cc(
     name = "dictionary_proto",
     srcs = ["dictionary.proto"],
 )

 tf_proto_library_py(
-    name = "dictionary_py_pb2",
+    name = "dictionary_pb2",
     srcs = ["dictionary.proto"],
 )

-tf_proto_library(
+tf_proto_library_cc(
     name = "kbest_syntax_proto",
     srcs = ["kbest_syntax.proto"],
-    deps = [":sentence_proto"],
+    protodeps = [":sentence_proto"],
 )

-tf_proto_library(
+tf_proto_library_cc(
     name = "task_spec_proto",
     srcs = ["task_spec.proto"],
 )

 tf_proto_library_py(
-    name = "task_spec_py_pb2",
+    name = "task_spec_pb2",
     srcs = ["task_spec.proto"],
 )

-tf_proto_library(
+tf_proto_library_cc(
     name = "sparse_proto",
     srcs = ["sparse.proto"],
 )

 tf_proto_library_py(
-    name = "sparse_py_pb2",
+    name = "sparse_pb2",
     srcs = ["sparse.proto"],
 )
@@ -79,11 +82,10 @@ cc_library(
     visibility = ["//visibility:public"],
     deps = [
         "@com_googlesource_code_re2//:re2",
-        "@protobuf_archive//:protobuf",
         "@org_tensorflow//third_party/eigen3",
     ] + select({
         "//conditions:default": [
-            "@org_tensorflow//tensorflow/core:framework",
+            "@org_tensorflow//tensorflow/core:framework_headers_lib",
             "@org_tensorflow//tensorflow/core:lib",
         ],
         "@org_tensorflow//tensorflow:darwin": [
@@ -122,7 +124,7 @@ cc_library(
     hdrs = ["document_format.h"],
     deps = [
         ":registry",
-        ":sentence_proto",
+        ":sentence_proto_cc",
         ":task_context",
     ],
 )
@@ -134,7 +136,7 @@ cc_library(
         ":base",
         ":document_format",
         ":segmenter_utils",
-        ":sentence_proto",
+        ":sentence_proto_cc",
     ],
     alwayslink = 1,
 )
@@ -144,7 +146,7 @@ cc_library(
     srcs = ["fml_parser.cc"],
     hdrs = ["fml_parser.h"],
     deps = [
-        ":feature_extractor_proto",
+        ":feature_extractor_proto_cc",
         ":utils",
     ],
 )
@@ -153,9 +155,9 @@ cc_library(
     name = "proto_io",
     hdrs = ["proto_io.h"],
     deps = [
-        ":feature_extractor_proto",
+        ":feature_extractor_proto_cc",
         ":fml_parser",
-        ":sentence_proto",
+        ":sentence_proto_cc",
         ":task_context",
     ],
 )
@@ -168,6 +170,7 @@ cc_library(
         ":registry",
         ":utils",
         "//util/utf8:unicodetext",
+        "@com_google_absl//absl/base:core_headers",
     ],
     alwayslink = 1,
 )
@@ -190,7 +193,7 @@ cc_library(
     deps = [
         ":base",
         ":char_properties",
-        ":sentence_proto",
+        ":sentence_proto_cc",
         "//util/utf8:unicodetext",
     ],
     alwayslink = 1,
@@ -205,9 +208,9 @@ cc_library(
     ],
     deps = [
         ":document_format",
-        ":feature_extractor_proto",
+        ":feature_extractor_proto_cc",
         ":proto_io",
-        ":sentence_proto",
+        ":sentence_proto_cc",
         ":task_context",
         ":utils",
         ":workspace",
@@ -219,9 +222,9 @@ cc_library(
     srcs = ["affix.cc"],
     hdrs = ["affix.h"],
     deps = [
-        ":dictionary_proto",
+        ":dictionary_proto_cc",
         ":feature_extractor",
-        ":sentence_proto",
+        ":sentence_proto_cc",
         ":shared_store",
         ":term_frequency_map",
         ":utils",
@@ -276,7 +279,9 @@ cc_library(
     srcs = ["registry.cc"],
     hdrs = ["registry.h"],
     deps = [
+        ":base",
         ":utils",
+        "@org_tensorflow//tensorflow/core:lib",
     ],
 )
@@ -294,7 +299,7 @@ cc_library(
     srcs = ["task_context.cc"],
     hdrs = ["task_context.h"],
     deps = [
-        ":task_spec_proto",
+        ":task_spec_proto_cc",
         ":utils",
     ],
 )
@@ -307,7 +312,6 @@ cc_library(
     deps = [
         ":utils",
     ],
-    alwayslink = 1,
 )

 cc_library(
@@ -319,7 +323,7 @@ cc_library(
         ":feature_extractor",
         ":proto_io",
         ":registry",
-        ":sentence_proto",
+        ":sentence_proto_cc",
         ":utils",
     ],
 )
@@ -360,7 +364,7 @@ cc_library(
         ":registry",
         ":segmenter_utils",
         ":sentence_features",
-        ":sentence_proto",
+        ":sentence_proto_cc",
         ":shared_store",
         ":task_context",
         ":term_frequency_map",
@@ -377,10 +381,10 @@ cc_library(
     srcs = ["populate_test_inputs.cc"],
     hdrs = ["populate_test_inputs.h"],
     deps = [
-        ":dictionary_proto",
-        ":sentence_proto",
+        ":dictionary_proto_cc",
+        ":sentence_proto_cc",
         ":task_context",
-        ":task_spec_proto",
+        ":task_spec_proto_cc",
         ":term_frequency_map",
         ":test_main",
     ],
@@ -395,7 +399,7 @@ cc_library(
         ":feature_extractor",
         ":parser_transitions",
         ":sentence_features",
-        ":sparse_proto",
+        ":sparse_proto_cc",
         ":task_context",
         ":utils",
         ":workspace",
@@ -420,10 +424,10 @@ cc_library(
         ":embedding_feature_extractor",
         ":feature_extractor",
         ":parser_transitions",
-        ":sentence_proto",
-        ":sparse_proto",
+        ":sentence_proto_cc",
+        ":sparse_proto_cc",
         ":task_context",
-        ":task_spec_proto",
+        ":task_spec_proto_cc",
         ":term_frequency_map",
         ":workspace",
     ],
@@ -438,10 +442,10 @@ cc_library(
     deps = [
         ":parser_transitions",
         ":sentence_batch",
-        ":sentence_proto",
-        ":sparse_proto",
+        ":sentence_proto_cc",
+        ":sparse_proto_cc",
         ":task_context",
-        ":task_spec_proto",
+        ":task_spec_proto_cc",
     ],
     alwayslink = 1,
 )
@@ -454,7 +458,7 @@ cc_library(
         ":parser_transitions",
         ":segmenter_utils",
         ":sentence_batch",
-        ":sentence_proto",
+        ":sentence_proto_cc",
         ":task_context",
         ":text_formats",
     ],
@@ -472,7 +476,7 @@ cc_library(
         ":parser_transitions",
         ":segmenter_utils",
         ":sentence_batch",
-        ":sentence_proto",
+        ":sentence_proto_cc",
         ":term_frequency_map",
         ":text_formats",
         ":utils",
@@ -484,12 +488,20 @@ cc_library(
     name = "unpack_sparse_features",
     srcs = ["unpack_sparse_features.cc"],
     deps = [
-        ":sparse_proto",
+        ":sparse_proto_cc",
         ":utils",
     ],
     alwayslink = 1,
 )

+cc_library(
+    name = "shape_helpers",
+    hdrs = ["ops/shape_helpers.h"],
+    deps = [
+        "@org_tensorflow//tensorflow/core:framework_headers_lib",
+    ],
+)
+
 cc_library(
     name = "parser_ops_cc",
     srcs = ["ops/parser_ops.cc"],
@@ -498,6 +510,7 @@ cc_library(
         ":document_filters",
         ":lexicon_builder",
         ":reader_ops",
+        ":shape_helpers",
         ":unpack_sparse_features",
     ],
     alwayslink = 1,
@@ -581,7 +594,7 @@ cc_test(
     deps = [
         ":base",
         ":segmenter_utils",
-        ":sentence_proto",
+        ":sentence_proto_cc",
         ":test_main",
     ],
 )
@@ -605,9 +618,9 @@ cc_test(
         ":feature_extractor",
         ":populate_test_inputs",
         ":sentence_features",
-        ":sentence_proto",
+        ":sentence_proto_cc",
         ":task_context",
-        ":task_spec_proto",
+        ":task_spec_proto_cc",
         ":term_frequency_map",
         ":test_main",
         ":utils",
@@ -622,7 +635,7 @@ cc_test(
     deps = [
         ":feature_extractor",
         ":parser_transitions",
-        ":sentence_proto",
+        ":sentence_proto_cc",
         ":task_context",
         ":term_frequency_map",
         ":test_main",
@@ -648,8 +661,8 @@ cc_test(
     deps = [
         ":parser_transitions",
         ":populate_test_inputs",
-        ":sentence_proto",
-        ":task_spec_proto",
+        ":sentence_proto_cc",
+        ":task_spec_proto_cc",
         ":test_main",
     ],
 )
@@ -662,8 +675,8 @@ cc_test(
     deps = [
         ":parser_transitions",
         ":populate_test_inputs",
-        ":sentence_proto",
-        ":task_spec_proto",
+        ":sentence_proto_cc",
+        ":task_spec_proto_cc",
         ":test_main",
     ],
 )
@@ -674,7 +687,7 @@ cc_test(
     srcs = ["binary_segment_transitions_test.cc"],
     deps = [
         ":parser_transitions",
-        ":sentence_proto",
+        ":sentence_proto_cc",
         ":task_context",
         ":test_main",
         ":workspace",
@@ -689,8 +702,8 @@ cc_test(
     deps = [
         ":parser_transitions",
         ":populate_test_inputs",
-        ":sentence_proto",
-        ":task_spec_proto",
+        ":sentence_proto_cc",
+        ":task_spec_proto_cc",
         ":test_main",
     ],
 )
@@ -702,7 +715,7 @@ cc_test(
     deps = [
         ":base",
         ":parser_transitions",
-        ":sentence_proto",
+        ":sentence_proto_cc",
         ":task_context",
         ":term_frequency_map",
         ":test_main",
@@ -716,7 +729,7 @@ cc_test(
     deps = [
         ":base",
         ":parser_transitions",
-        ":sentence_proto",
+        ":sentence_proto_cc",
         ":task_context",
         ":term_frequency_map",
         ":test_main",
@@ -730,7 +743,7 @@ cc_test(
     deps = [
         ":base",
         ":parser_transitions",
-        ":sentence_proto",
+        ":sentence_proto_cc",
         ":task_context",
         ":term_frequency_map",
         ":test_main",
@@ -744,7 +757,7 @@ cc_test(
     deps = [
         ":base",
         ":parser_transitions",
-        ":sentence_proto",
+        ":sentence_proto_cc",
         ":task_context",
         ":term_frequency_map",
         ":test_main",
@@ -759,19 +772,69 @@ cc_test(
         ":feature_extractor",
         ":parser_transitions",
         ":populate_test_inputs",
-        ":sentence_proto",
+        ":sentence_proto_cc",
         ":task_context",
-        ":task_spec_proto",
+        ":task_spec_proto_cc",
         ":term_frequency_map",
         ":test_main",
         ":workspace",
     ],
 )

+cc_test(
+    name = "term_frequency_map_test",
+    size = "small",
+    srcs = ["term_frequency_map_test.cc"],
+    deps = [
+        ":base",
+        ":term_frequency_map",
+        ":test_main",
+    ],
+)
+
+cc_test(
+    name = "fml_parser_test",
+    srcs = ["fml_parser_test.cc"],
+    deps = [
+        ":base",
+        ":feature_extractor_proto_cc",
+        ":fml_parser",
+        ":test_main",
+        "@org_tensorflow//tensorflow/core:lib",
+        "@org_tensorflow//tensorflow/core:test",
+    ],
+)
+
+cc_test(
+    name = "registry_test",
+    srcs = ["registry_test.cc"],
+    deps = [
+        ":base",
+        ":registry",
+        ":test_main",
+        "//dragnn/core/test:generic",
+        "@org_tensorflow//tensorflow/core:test",
+    ],
+)
+
+cc_test(
+    name = "registry_test_with_duplicate",
+    srcs = ["registry_test.cc"],
+    defines = ["DRAGNN_REGISTRY_TEST_WITH_DUPLICATE"],
+    deps = [
+        ":base",
+        ":registry",
+        ":test_main",
+        "//dragnn/core/test:generic",
+        "@org_tensorflow//tensorflow/core:test",
+    ],
+)
+
 # py graph builder and trainer

 tf_gen_op_libs(
     op_lib_names = ["parser_ops"],
+    deps = [":shape_helpers"],
 )

 tf_gen_op_wrapper_py(
@@ -819,7 +882,9 @@ py_binary(
     deps = [
         ":graph_builder",
         ":structured_graph_builder",
-        ":task_spec_py_pb2",
+        ":task_spec_pb2_py",
+        "@absl_py//absl:app",
+        "@absl_py//absl/flags",
     ],
 )
@@ -828,9 +893,11 @@ py_binary(
     srcs = ["parser_eval.py"],
     deps = [
         ":graph_builder",
-        ":sentence_py_pb2",
+        ":sentence_pb2_py",
         ":structured_graph_builder",
-        ":task_spec_py_pb2",
+        ":task_spec_pb2_py",
+        "@absl_py//absl:app",
+        "@absl_py//absl/flags",
     ],
 )
@@ -839,7 +906,18 @@ py_binary(
     srcs = ["conll2tree.py"],
     deps = [
         ":graph_builder",
-        ":sentence_py_pb2",
+        ":sentence_pb2_py",
+        "@absl_py//absl:app",
+        "@absl_py//absl/flags",
+    ],
+)
+
+py_library(
+    name = "test_flags",
+    srcs = ["test_flags.py"],
+    deps = [
+        "@absl_py//absl/flags",
+        "@org_tensorflow//tensorflow:tensorflow_py",
     ],
 )
@@ -851,8 +929,9 @@ py_test(
     srcs = ["lexicon_builder_test.py"],
     deps = [
         ":graph_builder",
-        ":sentence_py_pb2",
-        ":task_spec_py_pb2",
+        ":sentence_pb2_py",
+        ":task_spec_pb2_py",
+        "//syntaxnet:test_flags",
     ],
 )
@@ -862,8 +941,9 @@ py_test(
     srcs = ["text_formats_test.py"],
     deps = [
         ":graph_builder",
-        ":sentence_py_pb2",
-        ":task_spec_py_pb2",
+        ":sentence_pb2_py",
+        ":task_spec_pb2_py",
+        "//syntaxnet:test_flags",
     ],
 )
@@ -874,9 +954,10 @@ py_test(
     data = [":testdata"],
     tags = ["notsan"],
     deps = [
-        ":dictionary_py_pb2",
+        ":dictionary_pb2_py",
         ":graph_builder",
-        ":sparse_py_pb2",
+        ":sparse_pb2_py",
+        "//syntaxnet:test_flags",
     ],
 )
@@ -888,6 +969,7 @@ py_test(
     tags = ["notsan"],
     deps = [
         ":structured_graph_builder",
+        "//syntaxnet:test_flags",
     ],
 )
@@ -901,7 +983,8 @@ py_test(
     tags = ["notsan"],
     deps = [
         ":graph_builder",
-        ":sparse_py_pb2",
+        ":sparse_pb2_py",
+        "//syntaxnet:test_flags",
     ],
 )
...
@@ -22,6 +22,9 @@ limitations under the License.
 #include <unordered_set>
 #include <vector>

+#include "google/protobuf/util/message_differencer.h"
 #include "tensorflow/core/lib/core/status.h"
 #include "tensorflow/core/lib/strings/strcat.h"
 #include "tensorflow/core/lib/strings/stringprintf.h"
...
@@ -20,32 +20,25 @@ import os.path
 import time
 import tensorflow as tf
-from tensorflow.python.framework import test_util
-from tensorflow.python.platform import googletest
 from tensorflow.python.platform import tf_logging as logging
 from syntaxnet import structured_graph_builder
+from syntaxnet import test_flags
 from syntaxnet.ops import gen_parser_ops

-FLAGS = tf.app.flags.FLAGS
-if not hasattr(FLAGS, 'test_srcdir'):
-  FLAGS.test_srcdir = ''
-if not hasattr(FLAGS, 'test_tmpdir'):
-  FLAGS.test_tmpdir = tf.test.get_temp_dir()

-class ParsingReaderOpsTest(test_util.TensorFlowTestCase):
+class ParsingReaderOpsTest(tf.test.TestCase):

   def setUp(self):
     # Creates a task context with the correct testing paths.
-    initial_task_context = os.path.join(FLAGS.test_srcdir,
+    initial_task_context = os.path.join(test_flags.source_root(),
                                         'syntaxnet/'
                                         'testdata/context.pbtxt')
-    self._task_context = os.path.join(FLAGS.test_tmpdir, 'context.pbtxt')
+    self._task_context = os.path.join(test_flags.temp_dir(), 'context.pbtxt')
     with open(initial_task_context, 'r') as fin:
       with open(self._task_context, 'w') as fout:
-        fout.write(fin.read().replace('SRCDIR', FLAGS.test_srcdir)
-                   .replace('OUTPATH', FLAGS.test_tmpdir))
+        fout.write(fin.read().replace('SRCDIR', test_flags.source_root())
+                   .replace('OUTPATH', test_flags.temp_dir()))

     # Creates necessary term maps.
     with self.test_session() as sess:
@@ -225,4 +218,4 @@ class ParsingReaderOpsTest(test_util.TensorFlowTestCase):
 if __name__ == '__main__':
-  googletest.main()
+  tf.test.main()
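The diff does not show the new `syntaxnet/test_flags.py` itself, only its BUILD target and call sites (`test_flags.source_root()`, `test_flags.temp_dir()`). A plausible sketch of what those helpers do, assuming they wrap the standard Bazel test environment that the old `FLAGS.test_srcdir`/`FLAGS.test_tmpdir` hack emulated:

```python
# Hypothetical sketch of syntaxnet/test_flags.py; this file is not part of
# the diff shown above.
import os

import tensorflow as tf


def source_root():
  # Bazel exposes the test runfiles root as TEST_SRCDIR.
  return os.environ.get('TEST_SRCDIR', '')


def temp_dir():
  # Prefer Bazel's TEST_TMPDIR; otherwise fall back to a TF-managed temp dir.
  return os.environ.get('TEST_TMPDIR') or tf.test.get_temp_dir()
```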
@@ -73,6 +73,7 @@ limitations under the License.
 #include "syntaxnet/registry.h"
 #include "syntaxnet/utils.h"
+#include "absl/base/macros.h"

 // =====================================================================
 // Registry for accessing CharProperties by name
@@ -128,7 +129,7 @@ struct CharPropertyWrapper : RegisterableClass<CharPropertyWrapper> {
   static const int k_##name##_unicodes[] = {unicodes};                         \
   static utils::LazyStaticPtr<CharProperty, const char *, const int *, size_t> \
       name##_char_property = {#name, k_##name##_unicodes,                      \
-                              arraysize(k_##name##_unicodes)};                 \
+                              ABSL_ARRAYSIZE(k_##name##_unicodes)};            \
   REGISTER_CHAR_PROPERTY(name##_char_property, name);                          \
   DEFINE_IS_X_CHAR_PROPERTY_FUNCTIONS(name##_char_property, name)
...
@@ -187,7 +187,7 @@ DEFINE_CHAR_PROPERTY(test_punctuation_plus, prop) {
   prop->AddCharRange('b', 'b');
   prop->AddCharRange('c', 'e');
   static const int kUnicodes[] = {'f', RANGE('g', 'i'), 'j'};
-  prop->AddCharSpec(kUnicodes, arraysize(kUnicodes));
+  prop->AddCharSpec(kUnicodes, ABSL_ARRAYSIZE(kUnicodes));
   prop->AddCharProperty("punctuation");
 }
@@ -223,25 +223,25 @@ const char32 kTestPunctuationPlusExtras[] = {
 //
 TEST_F(CharPropertiesTest, TestDigit) {
-  CollectArray(kTestDigit, arraysize(kTestDigit));
+  CollectArray(kTestDigit, ABSL_ARRAYSIZE(kTestDigit));
   ExpectCharPropertyEqualsCollectedSet("test_digit");
 }

 TEST_F(CharPropertiesTest, TestWavyDash) {
-  CollectArray(kTestWavyDash, arraysize(kTestWavyDash));
+  CollectArray(kTestWavyDash, ABSL_ARRAYSIZE(kTestWavyDash));
   ExpectCharPropertyEqualsCollectedSet("test_wavy_dash");
 }

 TEST_F(CharPropertiesTest, TestDigitOrWavyDash) {
-  CollectArray(kTestDigit, arraysize(kTestDigit));
-  CollectArray(kTestWavyDash, arraysize(kTestWavyDash));
+  CollectArray(kTestDigit, ABSL_ARRAYSIZE(kTestDigit));
+  CollectArray(kTestWavyDash, ABSL_ARRAYSIZE(kTestWavyDash));
   ExpectCharPropertyEqualsCollectedSet("test_digit_or_wavy_dash");
 }

 TEST_F(CharPropertiesTest, TestPunctuationPlus) {
   CollectCharProperty("punctuation");
-  CollectArray(kTestPunctuationPlusExtras,
-               arraysize(kTestPunctuationPlusExtras));
+  CollectArray(kTestPunctuationPlusExtras,
+               ABSL_ARRAYSIZE(kTestPunctuationPlusExtras));
   ExpectCharPropertyEqualsCollectedSet("test_punctuation_plus");
 }
...
@@ -110,7 +110,9 @@ bool CharShiftTransitionState::IsTokenEnd(int i) const {
 }

 void CharShiftTransitionSystem::Setup(TaskContext *context) {
-  left_to_right_ = context->Get("left-to-right", true);
+  // The version with underscores takes precedence if explicitly set.
+  left_to_right_ =
+      context->Get("left_to_right", context->Get("left-to-right", true));
 }

 bool CharShiftTransitionSystem::IsAllowedAction(
...
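The nested `Get` gives the underscore spelling priority while keeping the legacy hyphenated key working for old task specs. The same fallback logic restated in Python, with a dict standing in for the `TaskContext`:

```python
def left_to_right(params, default=True):
  # 'left_to_right' wins when present; otherwise fall back to the legacy
  # hyphenated 'left-to-right'; otherwise use the default.
  return params.get('left_to_right', params.get('left-to-right', default))


assert left_to_right({}) is True
assert left_to_right({'left-to-right': False}) is False
assert left_to_right({'left_to_right': True, 'left-to-right': False}) is True
```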
@@ -76,7 +76,7 @@ class CharShiftTransitionTest : public ::testing::Test {
   }

   void PrepareCharTransition(bool left_to_right) {
-    context_.SetParameter("left-to-right", left_to_right ? "true" : "false");
+    context_.SetParameter("left_to_right", left_to_right ? "true" : "false");
     transition_system_.reset(ParserTransitionSystem::Create("char-shift-only"));
     transition_system_->Setup(&context_);
@@ -88,7 +88,7 @@ class CharShiftTransitionTest : public ::testing::Test {
   }

   void PrepareShiftTransition(bool left_to_right) {
-    context_.SetParameter("left-to-right", left_to_right ? "true" : "false");
+    context_.SetParameter("left_to_right", left_to_right ? "true" : "false");
     transition_system_.reset(ParserTransitionSystem::Create("shift-only"));
     transition_system_->Setup(&context_);
     state_.reset(new ParserState(
...
@@ -17,6 +17,8 @@
 import collections
 import re
+from absl import app
+from absl import flags
 import asciitree
 import tensorflow as tf
@@ -26,7 +28,6 @@ from tensorflow.python.platform import tf_logging as logging
 from syntaxnet import sentence_pb2
 from syntaxnet.ops import gen_parser_ops
-flags = tf.app.flags
 FLAGS = flags.FLAGS
 flags.DEFINE_string('task_context',
@@ -88,16 +89,16 @@ def main(unused_argv):
       sentence.ParseFromString(d)
       tr = asciitree.LeftAligned()
       d = to_dict(sentence)
-      print('Input: %s' % sentence.text)
-      print('Parse:')
+      print 'Input: %s' % sentence.text
+      print 'Parse:'
       tr_str = tr(d)
       pat = re.compile(r'\s*@\d+$')
       for tr_ln in tr_str.splitlines():
-        print(pat.sub('', tr_ln))
+        print pat.sub('', tr_ln)

       if finished:
         break

 if __name__ == '__main__':
-  tf.app.run()
+  app.run(main)
@@ -101,6 +101,13 @@ int GenericFeatureFunction::GetIntParameter(const string &name,
                                             tensorflow::strings::safe_strto32);
 }

+double GenericFeatureFunction::GetFloatParameter(const string &name,
+                                                 double default_value) const {
+  const string value = GetParameter(name);
+  return utils::ParseUsing<double>(value, default_value,
+                                   tensorflow::strings::safe_strtod);
+}
+
 bool GenericFeatureFunction::GetBoolParameter(const string &name,
                                               bool default_value) const {
   const string value = GetParameter(name);
...
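The new `GetFloatParameter` follows the exact shape of the existing `GetIntParameter` and `GetBoolParameter`: read the raw string parameter, then parse it with a default covering the missing or unparseable case. The contract, restated as a small Python sketch (names here are illustrative, not part of the library):

```python
def get_float_parameter(params, name, default_value):
  # Mirrors utils::ParseUsing<double>: return the parsed value, or the
  # default when the parameter is absent or not a valid float.
  value = params.get(name, '')
  try:
    return float(value)
  except ValueError:
    return default_value


assert get_float_parameter({'dropout': '0.5'}, 'dropout', 1.0) == 0.5
assert get_float_parameter({}, 'dropout', 1.0) == 1.0
```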