Unverified Commit 80178fc6 authored by Mark Omernick, committed by GitHub

Merge pull request #4153 from terryykoo/master

Export @195097388.
parents a84e1ef9 edea2b67
......@@ -26,6 +26,7 @@ import os
import re
import time
from absl import flags
import tensorflow as tf
from google.protobuf import text_format
......@@ -39,7 +40,6 @@ from dragnn.python import sentence_io
from dragnn.python import spec_builder
from syntaxnet import sentence_pb2
flags = tf.app.flags
FLAGS = flags.FLAGS
flags.DEFINE_string('master_spec', '',
......
......@@ -19,6 +19,7 @@ r"""Runs a both a segmentation and parsing model on a CoNLL dataset.
import re
import time
from absl import flags
import tensorflow as tf
from google.protobuf import text_format
......@@ -34,7 +35,6 @@ from syntaxnet import sentence_pb2
from syntaxnet.ops import gen_parser_ops
from syntaxnet.util import check
flags = tf.app.flags
FLAGS = flags.FLAGS
......
......@@ -42,6 +42,8 @@ import ast
import collections
import os
import os.path
from absl import app
from absl import flags
import tensorflow as tf
from google.protobuf import text_format
......@@ -55,7 +57,6 @@ from dragnn.python import trainer_lib
from syntaxnet.ops import gen_parser_ops
from syntaxnet.util import check
flags = tf.app.flags
FLAGS = flags.FLAGS
flags.DEFINE_string('tf_master', '',
......@@ -191,4 +192,4 @@ def main(unused_argv):
if __name__ == '__main__':
tf.app.run()
app.run(main)
# Copyright 2017 Google Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
r"""Runs a both a segmentation and parsing model on a CoNLL dataset.
"""
import re
import time
import tensorflow as tf
from google.protobuf import text_format
from tensorflow.python.client import timeline
from tensorflow.python.platform import gfile
from dragnn.protos import spec_pb2
from dragnn.python import graph_builder
from dragnn.python import sentence_io
from dragnn.python import spec_builder
from syntaxnet import sentence_pb2
from syntaxnet.ops import gen_parser_ops
from syntaxnet.util import check
flags = tf.app.flags
FLAGS = flags.FLAGS
flags.DEFINE_string('parser_master_spec', '',
'Path to text file containing a DRAGNN master spec to run.')
flags.DEFINE_string('parser_checkpoint_file', '',
'Path to trained model checkpoint.')
flags.DEFINE_string('parser_resource_dir', '',
'Optional base directory for resources in the master spec.')
flags.DEFINE_string('segmenter_master_spec', '',
'Path to text file containing a DRAGNN master spec to run.')
flags.DEFINE_string('segmenter_checkpoint_file', '',
'Path to trained model checkpoint.')
flags.DEFINE_string('segmenter_resource_dir', '',
'Optional base directory for resources in the master spec.')
flags.DEFINE_bool('complete_master_spec', True, 'Whether the master specs '
'need the lexicon and other resources added to them.')
flags.DEFINE_string('input_file', '',
'File of CoNLL-formatted sentences to read from.')
flags.DEFINE_string('output_file', '',
'File path to write annotated sentences to.')
flags.DEFINE_integer('max_batch_size', 2048, 'Maximum batch size to support.')
flags.DEFINE_string('inference_beam_size', '', 'Comma separated list of '
'component_name=beam_size pairs.')
flags.DEFINE_string('locally_normalize', '', 'Comma separated list of '
'component names to do local normalization on.')
flags.DEFINE_integer('threads', 10, 'Number of threads used for intra- and '
'inter-op parallelism.')
flags.DEFINE_string('timeline_output_file', '', 'Path to save timeline to. '
'If specified, the final iteration of the evaluation loop '
'will capture and save a TensorFlow timeline.')
flags.DEFINE_bool('use_gold_segmentation', False,
'Whether or not to use gold segmentation.')
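# Example invocation (binary name omitted; the paths and values below are
# illustrative only):
#   --segmenter_master_spec=/tmp/segmenter/master_spec.textproto \
#   --segmenter_checkpoint_file=/tmp/segmenter/checkpoint \
#   --parser_master_spec=/tmp/parser/master_spec.textproto \
#   --parser_checkpoint_file=/tmp/parser/checkpoint \
#   --input_file=/tmp/input.conll --output_file=/tmp/output.conll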
def main(unused_argv):
# Parse the flags containing lists, using regular expressions.
# This matches and extracts key=value pairs.
component_beam_sizes = re.findall(r'([^=,]+)=(\d+)',
FLAGS.inference_beam_size)
# This matches strings separated by a comma. Does not return any empty
# strings.
components_to_locally_normalize = re.findall(r'[^,]+',
FLAGS.locally_normalize)
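# For example (flag values and component names here are illustrative), passing
# --inference_beam_size='parser=8,tagger=4' yields [('parser', '8'),
# ('tagger', '4')], and --locally_normalize='tagger' yields ['tagger'].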
## SEGMENTATION ##
if not FLAGS.use_gold_segmentation:
# Reads master spec.
master_spec = spec_pb2.MasterSpec()
with gfile.FastGFile(FLAGS.segmenter_master_spec) as fin:
text_format.Parse(fin.read(), master_spec)
if FLAGS.complete_master_spec:
spec_builder.complete_master_spec(
master_spec, None, FLAGS.segmenter_resource_dir)
# Graph building.
tf.logging.info('Building the graph')
g = tf.Graph()
with g.as_default(), tf.device('/device:CPU:0'):
hyperparam_config = spec_pb2.GridPoint()
hyperparam_config.use_moving_average = True
builder = graph_builder.MasterBuilder(master_spec, hyperparam_config)
annotator = builder.add_annotation()
builder.add_saver()
tf.logging.info('Reading documents...')
input_corpus = sentence_io.ConllSentenceReader(FLAGS.input_file).corpus()
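# The segmenter operates on character-level input, so first convert the
# CoNLL-read corpus into character tokens in a throwaway session.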
with tf.Session(graph=tf.Graph()) as tmp_session:
char_input = gen_parser_ops.char_token_generator(input_corpus)
char_corpus = tmp_session.run(char_input)
check.Eq(len(input_corpus), len(char_corpus))
session_config = tf.ConfigProto(
log_device_placement=False,
intra_op_parallelism_threads=FLAGS.threads,
inter_op_parallelism_threads=FLAGS.threads)
with tf.Session(graph=g, config=session_config) as sess:
tf.logging.info('Initializing variables...')
sess.run(tf.global_variables_initializer())
tf.logging.info('Loading from checkpoint...')
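# Restore by running the saver's restore op directly, feeding the checkpoint
# path into the saver's filename constant ('save/Const:0').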
sess.run('save/restore_all',
{'save/Const:0': FLAGS.segmenter_checkpoint_file})
tf.logging.info('Processing sentences...')
processed = []
start_time = time.time()
run_metadata = tf.RunMetadata()
for start in range(0, len(char_corpus), FLAGS.max_batch_size):
end = min(start + FLAGS.max_batch_size, len(char_corpus))
feed_dict = {annotator['input_batch']: char_corpus[start:end]}
if FLAGS.timeline_output_file and end == len(char_corpus):
serialized_annotations = sess.run(
annotator['annotations'], feed_dict=feed_dict,
options=tf.RunOptions(trace_level=tf.RunOptions.FULL_TRACE),
run_metadata=run_metadata)
trace = timeline.Timeline(step_stats=run_metadata.step_stats)
with open(FLAGS.timeline_output_file, 'w') as trace_file:
trace_file.write(trace.generate_chrome_trace_format())
else:
serialized_annotations = sess.run(
annotator['annotations'], feed_dict=feed_dict)
processed.extend(serialized_annotations)
tf.logging.info('Processed %d documents in %.2f seconds.',
len(char_corpus), time.time() - start_time)
input_corpus = processed
else:
input_corpus = sentence_io.ConllSentenceReader(FLAGS.input_file).corpus()
## PARSING ##
# Reads master spec.
master_spec = spec_pb2.MasterSpec()
with gfile.FastGFile(FLAGS.parser_master_spec) as fin:
text_format.Parse(fin.read(), master_spec)
if FLAGS.complete_master_spec:
spec_builder.complete_master_spec(
master_spec, None, FLAGS.parser_resource_dir)
# Graph building.
tf.logging.info('Building the graph')
g = tf.Graph()
with g.as_default(), tf.device('/device:CPU:0'):
hyperparam_config = spec_pb2.GridPoint()
hyperparam_config.use_moving_average = True
builder = graph_builder.MasterBuilder(master_spec, hyperparam_config)
annotator = builder.add_annotation()
builder.add_saver()
tf.logging.info('Reading documents...')
session_config = tf.ConfigProto(
log_device_placement=False,
intra_op_parallelism_threads=FLAGS.threads,
inter_op_parallelism_threads=FLAGS.threads)
with tf.Session(graph=g, config=session_config) as sess:
tf.logging.info('Initializing variables...')
sess.run(tf.global_variables_initializer())
tf.logging.info('Loading from checkpoint...')
sess.run('save/restore_all', {'save/Const:0': FLAGS.parser_checkpoint_file})
tf.logging.info('Processing sentences...')
processed = []
start_time = time.time()
run_metadata = tf.RunMetadata()
for start in range(0, len(input_corpus), FLAGS.max_batch_size):
end = min(start + FLAGS.max_batch_size, len(input_corpus))
feed_dict = {annotator['input_batch']: input_corpus[start:end]}
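# Optional per-component inference settings are fed through named tensors in
# the graph, e.g. 'parser/InferenceBeamSize:0' (component name illustrative).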
for comp, beam_size in component_beam_sizes:
feed_dict['%s/InferenceBeamSize:0' % comp] = beam_size
for comp in components_to_locally_normalize:
feed_dict['%s/LocallyNormalize:0' % comp] = True
if FLAGS.timeline_output_file and end == len(input_corpus):
serialized_annotations = sess.run(
annotator['annotations'], feed_dict=feed_dict,
options=tf.RunOptions(trace_level=tf.RunOptions.FULL_TRACE),
run_metadata=run_metadata)
trace = timeline.Timeline(step_stats=run_metadata.step_stats)
with open(FLAGS.timeline_output_file, 'w') as trace_file:
trace_file.write(trace.generate_chrome_trace_format())
else:
serialized_annotations = sess.run(
annotator['annotations'], feed_dict=feed_dict)
processed.extend(serialized_annotations)
tf.logging.info('Processed %d documents in %.2f seconds.',
len(input_corpus), time.time() - start_time)
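# Each annotated sentence is written as a CoNLL-X-style block of ten
# tab-separated columns: ID, FORM, four '_' placeholders, HEAD (1-based,
# 0 for the root), the dependency label, and two trailing '_' placeholders.
# An illustrative output line: "1\tHola\t_\t_\t_\t_\t0\tROOT\t_\t_".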
if FLAGS.output_file:
with gfile.GFile(FLAGS.output_file, 'w') as f:
for serialized_sentence in processed:
sentence = sentence_pb2.Sentence()
sentence.ParseFromString(serialized_sentence)
f.write('#' + sentence.text.encode('utf-8') + '\n')
for i, token in enumerate(sentence.token):
head = token.head + 1
f.write('%s\t%s\t_\t_\t_\t_\t%d\t%s\t_\t_\n'%(
i + 1,
token.word.encode('utf-8'), head,
token.label.encode('utf-8')))
f.write('\n\n')
if __name__ == '__main__':
tf.app.run()
......@@ -17,6 +17,7 @@ r"""Runs a both a segmentation and parsing model on a CoNLL dataset.
import re
import time
from absl import flags
import tensorflow as tf
from tensorflow.python.client import timeline
......@@ -35,7 +36,6 @@ from syntaxnet import syntaxnet_ops
from syntaxnet.ops import gen_parser_ops
from syntaxnet.util import check
flags = tf.app.flags
FLAGS = flags.FLAGS
flags.DEFINE_string(
......
......@@ -21,6 +21,8 @@ import os
import os.path
import random
import time
from absl import app
from absl import flags
import tensorflow as tf
from tensorflow.python.platform import gfile
......@@ -40,7 +42,6 @@ from dragnn.python import sentence_io
from dragnn.python import spec_builder
from dragnn.python import trainer_lib
flags = tf.app.flags
FLAGS = flags.FLAGS
flags.DEFINE_string('tf_master', '',
......@@ -189,4 +190,4 @@ def main(unused_argv):
if __name__ == '__main__':
tf.app.run()
app.run(main)
......@@ -27,6 +27,7 @@ import os
import re
import time
from absl import flags
import tensorflow as tf
from google.protobuf import text_format
......@@ -42,7 +43,6 @@ from syntaxnet import sentence_pb2
from syntaxnet.ops import gen_parser_ops
from syntaxnet.util import check
flags = tf.app.flags
FLAGS = flags.FLAGS
flags.DEFINE_string('master_spec', '',
......
......@@ -22,6 +22,7 @@ import os
import os.path
import random
import time
from absl import flags
import tensorflow as tf
from tensorflow.python.platform import gfile
......@@ -42,7 +43,6 @@ from dragnn.python import lexicon
from dragnn.python import spec_builder
from dragnn.python import trainer_lib
flags = tf.app.flags
FLAGS = flags.FLAGS
flags.DEFINE_string('tf_master', '',
......
......@@ -22,6 +22,7 @@ import os
import os.path
import random
import time
from absl import flags
import tensorflow as tf
from tensorflow.python.framework import errors
......@@ -45,7 +46,6 @@ from dragnn.python import trainer_lib
from syntaxnet.util import check
flags = tf.app.flags
FLAGS = flags.FLAGS
flags.DEFINE_string('tf_master', '',
......
/**
* Template for node info.
*/
goog.module('nlp.saft.opensource.dragnn.viz.node_info');
import preact from 'preact';
import _ from 'lodash';
const normalCell = {
'border': 0,
'border-collapse': 'separate',
'padding': '2px',
};
/**
* Style definitions which are directly injected (see README.md comments).
*/
const style = {
featuresTable: {
'background-color': 'rgba(255, 255, 255, 0.9)',
'border': '1px solid #dddddd',
'border-spacing': '2px',
'border-collapse': 'separate',
'font-family': 'roboto, helvetica, arial, sans-serif',
// Sometimes state strings (`stateHtml`) get long, and because this is an
// absolutely-positioned box, we need to make them wrap around.
'max-width': '600px',
'position': 'absolute',
},
heading: {
'background-color': '#ebf5fb',
'font-weight': 'bold',
'text-align': 'center',
...normalCell
},
normalCell: normalCell,
featureGroup: (componentColor) => ({
'background-color': componentColor,
'font-weight': 'bold',
...normalCell
}),
normalRow: {
'border': 0,
'border-collapse': 'separate',
},
};
/**
* Creates table rows that negate IPython/Jupyter notebook styling.
*
* @param {?XML|?Array<XML>} children Child nodes. (Recall Preact handles
* null/undefined gracefully).
* @param {!Object} props Any additional properties.
* @return {!XML} React-y element, representing a table row.
*/
const Row = ({children, ...props}) => (
<tr style={style.normalRow} {...props}>{children}</tr>);
/**
* Creates table cells that negate IPython/Jupyter notebook styling.
*
* @param {?XML|?Array<XML>} children Child nodes. (Recall Preact handles
* null/undefined gracefully).
* @param {!Object} props Any additional properties.
* @return {!XML} React-y element, representing a table cell.
*/
const Cell = ({children, ...props}) => (
<td style={style.normalCell} {...props}>{children}</td>);
/**
* Construct a table "multi-row" with a shared "header" cell.
*
* In ASCII-art,
*
* ------------------------------
* | row1
* header | row2
* | row3
* ------------------------------
*
* @param {string} headerText Text for the header cell
* @param {string} headerColor Color of the header cell
* @param {!Array<XML>} rowsCells Row cells (<td> React-y elements).
* @return {!Array<XML>} Array of React-y elements.
*/
const featureGroup = (headerText, headerColor, rowsCells) => {
const headerCell = (
<td rowspan={rowsCells.length} style={style.featureGroup(headerColor)}>
{headerText}
</td>
);
return _.map(rowsCells, (cells, i) => {
return <Row>{i == 0 ? headerCell : null}{cells}</Row>;
});
};
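// Illustrative usage (component name, color and cell contents are made up):
//   featureGroup('tagger', '#ffcc00', [[<Cell>a</Cell>], [<Cell>b</Cell>]])
// yields two <Row> elements; only the first row carries the shared header
// cell, which spans both rows via its rowspan attribute.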
/**
* Mini helper to intersperse line breaks with a list of elements.
*
* This just replicates previous behavior and looks OK; we could also try spans
* with `display: 'block'` or such.
*
* @param {!Array<XML>} elements React-y elements.
* @return {!Array<XML>} React-y elements with line breaks.
*/
const intersperseLineBreaks = (elements) => _.tail(_.flatten(_.map(
elements, (v) => [<br />, v]
)));
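// For example, intersperseLineBreaks([a, b, c]) evaluates to
// [a, <br />, b, <br />, c]: the leading <br /> produced by the map is
// dropped by _.tail().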
export default class NodeInfo extends preact.Component {
/**
* Obligatory Preact render() function.
*
* It might be worthwhile converting some of the intermediate variables into
* stateless functional components, like Cell and Row.
*
* @param {?Object} selected Cytoscape node selected (null if no selection).
* @param {?Object} mousePosition Mouse position, if a node is selected.
* @return {!XML} Preact components to render.
*/
render({selected, mousePosition}) {
const visible = selected != null;
const stateHtml = visible && selected.data('stateInfo');
// Generates elements for fixed features.
const fixedFeatures = visible ? selected.data('fixedFeatures') : [];
const fixedFeatureElements = _.map(fixedFeatures, (feature) => {
if (feature.value_trace.length == 0) {
// Preact will just prune this out.
return null;
} else {
const rowsCells = _.map(feature.value_trace, (value) => {
// Recall `value_name` is a list of strings (representing feature
// values), but this is OK because strings are valid react elements.
const valueCells = intersperseLineBreaks(value.value_name);
return [<Cell>{value.feature_name}</Cell>, <Cell>{valueCells}</Cell>];
});
return featureGroup(feature.name, '#cccccc', _.map(rowsCells));
}
});
/**
* Generates linked feature info from an edge.
*
* @param {!Object} edge Cytoscape JS Element representing a linked feature.
* @return {[XML,XML]} Linked feature information, as table elements.
*/
const linkedFeatureInfoFromEdge = (edge) => {
return [
<Cell>{edge.data('featureName')}</Cell>,
<Cell>
value {edge.data('featureValue')} from
step {edge.source().data('stepIdx')}
</Cell>
];
};
const linkedFeatureElements = _.flatten(
_.map(this.edgeStatesByComponent(), (edges, componentName) => {
// Because edges are generated by `incomers`, it is guaranteed to be
// non-empty.
const color = _.head(edges).source().parent().data('componentColor');
const rowsCells = _.map(edges, linkedFeatureInfoFromEdge);
return featureGroup(componentName, color, rowsCells);
}));
let positionOrHiddenStyle;
if (visible) {
positionOrHiddenStyle = {
left: mousePosition.x + 20,
top: mousePosition.y + 10,
};
} else {
positionOrHiddenStyle = {display: 'none'};
}
return (
<table style={_.defaults(positionOrHiddenStyle, style.featuresTable)}>
<Row>
<td colspan="3" style={style.heading}>State</td>
</Row>
<Row>
<Cell colspan="3">{stateHtml}</Cell>
</Row>
<Row>
<td colspan="3" style={style.heading}>Features</td>
</Row>
{fixedFeatureElements}
{linkedFeatureElements}
</table>
);
}
/**
* Gets a list of incoming edges, grouped by their component name.
*
* @return {!Object<string, !Array<!Object>>} Map from component name to list
* of edges.
*/
edgeStatesByComponent() {
if (this.props.selected == null) {
return [];
}
const incoming = this.props.selected.incomers(); // edges and nodes
return _.groupBy(incoming.edges(), (edge) => edge.source().parent().id());
}
}
......@@ -17,7 +17,7 @@ py_library(
deps = [
"//dragnn/core:dragnn_bulk_ops",
"//dragnn/core:dragnn_ops",
"//dragnn/protos:spec_py_pb2",
"//dragnn/protos:spec_pb2_py",
"//dragnn/python:graph_builder",
"//dragnn/python:lexicon",
"//dragnn/python:load_dragnn_cc_impl_py",
......@@ -25,7 +25,7 @@ py_library(
"//dragnn/python:visualization",
"//syntaxnet:load_parser_ops_py",
"//syntaxnet:parser_ops",
"//syntaxnet:sentence_py_pb2",
"//syntaxnet:sentence_pb2_py",
"@org_tensorflow//tensorflow:tensorflow_py",
"@org_tensorflow//tensorflow/core:protos_all_py",
],
......@@ -34,6 +34,7 @@ py_library(
filegroup(
name = "data",
data = glob(["tutorial_data/*"]),
visibility = ["//visibility:public"],
)
sh_test(
......
......@@ -11,63 +11,66 @@ package(
licenses(["notice"]) # Apache 2.0
load(
"syntaxnet",
"tf_proto_library",
"tf_proto_library_py",
"@org_tensorflow//tensorflow:tensorflow.bzl",
"tf_gen_op_libs",
"tf_gen_op_wrapper_py",
)
load(
":syntaxnet.bzl",
"tf_proto_library_cc",
"tf_proto_library_py",
)
# proto libraries
tf_proto_library(
tf_proto_library_cc(
name = "feature_extractor_proto",
srcs = ["feature_extractor.proto"],
)
tf_proto_library(
tf_proto_library_cc(
name = "sentence_proto",
srcs = ["sentence.proto"],
)
tf_proto_library_py(
name = "sentence_py_pb2",
name = "sentence_pb2",
srcs = ["sentence.proto"],
)
tf_proto_library(
tf_proto_library_cc(
name = "dictionary_proto",
srcs = ["dictionary.proto"],
)
tf_proto_library_py(
name = "dictionary_py_pb2",
name = "dictionary_pb2",
srcs = ["dictionary.proto"],
)
tf_proto_library(
tf_proto_library_cc(
name = "kbest_syntax_proto",
srcs = ["kbest_syntax.proto"],
deps = [":sentence_proto"],
protodeps = [":sentence_proto"],
)
tf_proto_library(
tf_proto_library_cc(
name = "task_spec_proto",
srcs = ["task_spec.proto"],
)
tf_proto_library_py(
name = "task_spec_py_pb2",
name = "task_spec_pb2",
srcs = ["task_spec.proto"],
)
tf_proto_library(
tf_proto_library_cc(
name = "sparse_proto",
srcs = ["sparse.proto"],
)
tf_proto_library_py(
name = "sparse_py_pb2",
name = "sparse_pb2",
srcs = ["sparse.proto"],
)
......@@ -79,11 +82,10 @@ cc_library(
visibility = ["//visibility:public"],
deps = [
"@com_googlesource_code_re2//:re2",
"@protobuf_archive//:protobuf",
"@org_tensorflow//third_party/eigen3",
] + select({
"//conditions:default": [
"@org_tensorflow//tensorflow/core:framework",
"@org_tensorflow//tensorflow/core:framework_headers_lib",
"@org_tensorflow//tensorflow/core:lib",
],
"@org_tensorflow//tensorflow:darwin": [
......@@ -122,7 +124,7 @@ cc_library(
hdrs = ["document_format.h"],
deps = [
":registry",
":sentence_proto",
":sentence_proto_cc",
":task_context",
],
)
......@@ -134,7 +136,7 @@ cc_library(
":base",
":document_format",
":segmenter_utils",
":sentence_proto",
":sentence_proto_cc",
],
alwayslink = 1,
)
......@@ -144,7 +146,7 @@ cc_library(
srcs = ["fml_parser.cc"],
hdrs = ["fml_parser.h"],
deps = [
":feature_extractor_proto",
":feature_extractor_proto_cc",
":utils",
],
)
......@@ -153,9 +155,9 @@ cc_library(
name = "proto_io",
hdrs = ["proto_io.h"],
deps = [
":feature_extractor_proto",
":feature_extractor_proto_cc",
":fml_parser",
":sentence_proto",
":sentence_proto_cc",
":task_context",
],
)
......@@ -168,6 +170,7 @@ cc_library(
":registry",
":utils",
"//util/utf8:unicodetext",
"@com_google_absl//absl/base:core_headers",
],
alwayslink = 1,
)
......@@ -190,7 +193,7 @@ cc_library(
deps = [
":base",
":char_properties",
":sentence_proto",
":sentence_proto_cc",
"//util/utf8:unicodetext",
],
alwayslink = 1,
......@@ -205,9 +208,9 @@ cc_library(
],
deps = [
":document_format",
":feature_extractor_proto",
":feature_extractor_proto_cc",
":proto_io",
":sentence_proto",
":sentence_proto_cc",
":task_context",
":utils",
":workspace",
......@@ -219,9 +222,9 @@ cc_library(
srcs = ["affix.cc"],
hdrs = ["affix.h"],
deps = [
":dictionary_proto",
":dictionary_proto_cc",
":feature_extractor",
":sentence_proto",
":sentence_proto_cc",
":shared_store",
":term_frequency_map",
":utils",
......@@ -276,7 +279,9 @@ cc_library(
srcs = ["registry.cc"],
hdrs = ["registry.h"],
deps = [
":base",
":utils",
"@org_tensorflow//tensorflow/core:lib",
],
)
......@@ -294,7 +299,7 @@ cc_library(
srcs = ["task_context.cc"],
hdrs = ["task_context.h"],
deps = [
":task_spec_proto",
":task_spec_proto_cc",
":utils",
],
)
......@@ -307,7 +312,6 @@ cc_library(
deps = [
":utils",
],
alwayslink = 1,
)
cc_library(
......@@ -319,7 +323,7 @@ cc_library(
":feature_extractor",
":proto_io",
":registry",
":sentence_proto",
":sentence_proto_cc",
":utils",
],
)
......@@ -360,7 +364,7 @@ cc_library(
":registry",
":segmenter_utils",
":sentence_features",
":sentence_proto",
":sentence_proto_cc",
":shared_store",
":task_context",
":term_frequency_map",
......@@ -377,10 +381,10 @@ cc_library(
srcs = ["populate_test_inputs.cc"],
hdrs = ["populate_test_inputs.h"],
deps = [
":dictionary_proto",
":sentence_proto",
":dictionary_proto_cc",
":sentence_proto_cc",
":task_context",
":task_spec_proto",
":task_spec_proto_cc",
":term_frequency_map",
":test_main",
],
......@@ -395,7 +399,7 @@ cc_library(
":feature_extractor",
":parser_transitions",
":sentence_features",
":sparse_proto",
":sparse_proto_cc",
":task_context",
":utils",
":workspace",
......@@ -420,10 +424,10 @@ cc_library(
":embedding_feature_extractor",
":feature_extractor",
":parser_transitions",
":sentence_proto",
":sparse_proto",
":sentence_proto_cc",
":sparse_proto_cc",
":task_context",
":task_spec_proto",
":task_spec_proto_cc",
":term_frequency_map",
":workspace",
],
......@@ -438,10 +442,10 @@ cc_library(
deps = [
":parser_transitions",
":sentence_batch",
":sentence_proto",
":sparse_proto",
":sentence_proto_cc",
":sparse_proto_cc",
":task_context",
":task_spec_proto",
":task_spec_proto_cc",
],
alwayslink = 1,
)
......@@ -454,7 +458,7 @@ cc_library(
":parser_transitions",
":segmenter_utils",
":sentence_batch",
":sentence_proto",
":sentence_proto_cc",
":task_context",
":text_formats",
],
......@@ -472,7 +476,7 @@ cc_library(
":parser_transitions",
":segmenter_utils",
":sentence_batch",
":sentence_proto",
":sentence_proto_cc",
":term_frequency_map",
":text_formats",
":utils",
......@@ -484,12 +488,20 @@ cc_library(
name = "unpack_sparse_features",
srcs = ["unpack_sparse_features.cc"],
deps = [
":sparse_proto",
":sparse_proto_cc",
":utils",
],
alwayslink = 1,
)
cc_library(
name = "shape_helpers",
hdrs = ["ops/shape_helpers.h"],
deps = [
"@org_tensorflow//tensorflow/core:framework_headers_lib",
],
)
cc_library(
name = "parser_ops_cc",
srcs = ["ops/parser_ops.cc"],
......@@ -498,6 +510,7 @@ cc_library(
":document_filters",
":lexicon_builder",
":reader_ops",
":shape_helpers",
":unpack_sparse_features",
],
alwayslink = 1,
......@@ -581,7 +594,7 @@ cc_test(
deps = [
":base",
":segmenter_utils",
":sentence_proto",
":sentence_proto_cc",
":test_main",
],
)
......@@ -605,9 +618,9 @@ cc_test(
":feature_extractor",
":populate_test_inputs",
":sentence_features",
":sentence_proto",
":sentence_proto_cc",
":task_context",
":task_spec_proto",
":task_spec_proto_cc",
":term_frequency_map",
":test_main",
":utils",
......@@ -622,7 +635,7 @@ cc_test(
deps = [
":feature_extractor",
":parser_transitions",
":sentence_proto",
":sentence_proto_cc",
":task_context",
":term_frequency_map",
":test_main",
......@@ -648,8 +661,8 @@ cc_test(
deps = [
":parser_transitions",
":populate_test_inputs",
":sentence_proto",
":task_spec_proto",
":sentence_proto_cc",
":task_spec_proto_cc",
":test_main",
],
)
......@@ -662,8 +675,8 @@ cc_test(
deps = [
":parser_transitions",
":populate_test_inputs",
":sentence_proto",
":task_spec_proto",
":sentence_proto_cc",
":task_spec_proto_cc",
":test_main",
],
)
......@@ -674,7 +687,7 @@ cc_test(
srcs = ["binary_segment_transitions_test.cc"],
deps = [
":parser_transitions",
":sentence_proto",
":sentence_proto_cc",
":task_context",
":test_main",
":workspace",
......@@ -689,8 +702,8 @@ cc_test(
deps = [
":parser_transitions",
":populate_test_inputs",
":sentence_proto",
":task_spec_proto",
":sentence_proto_cc",
":task_spec_proto_cc",
":test_main",
],
)
......@@ -702,7 +715,7 @@ cc_test(
deps = [
":base",
":parser_transitions",
":sentence_proto",
":sentence_proto_cc",
":task_context",
":term_frequency_map",
":test_main",
......@@ -716,7 +729,7 @@ cc_test(
deps = [
":base",
":parser_transitions",
":sentence_proto",
":sentence_proto_cc",
":task_context",
":term_frequency_map",
":test_main",
......@@ -730,7 +743,7 @@ cc_test(
deps = [
":base",
":parser_transitions",
":sentence_proto",
":sentence_proto_cc",
":task_context",
":term_frequency_map",
":test_main",
......@@ -744,7 +757,7 @@ cc_test(
deps = [
":base",
":parser_transitions",
":sentence_proto",
":sentence_proto_cc",
":task_context",
":term_frequency_map",
":test_main",
......@@ -759,19 +772,69 @@ cc_test(
":feature_extractor",
":parser_transitions",
":populate_test_inputs",
":sentence_proto",
":sentence_proto_cc",
":task_context",
":task_spec_proto",
":task_spec_proto_cc",
":term_frequency_map",
":test_main",
":workspace",
],
)
cc_test(
name = "term_frequency_map_test",
size = "small",
srcs = ["term_frequency_map_test.cc"],
deps = [
":base",
":term_frequency_map",
":test_main",
],
)
cc_test(
name = "fml_parser_test",
srcs = ["fml_parser_test.cc"],
deps = [
":base",
":feature_extractor_proto_cc",
":fml_parser",
":test_main",
"@org_tensorflow//tensorflow/core:lib",
"@org_tensorflow//tensorflow/core:test",
],
)
cc_test(
name = "registry_test",
srcs = ["registry_test.cc"],
deps = [
":base",
":registry",
":test_main",
"//dragnn/core/test:generic",
"@org_tensorflow//tensorflow/core:test",
],
)
cc_test(
name = "registry_test_with_duplicate",
srcs = ["registry_test.cc"],
defines = ["DRAGNN_REGISTRY_TEST_WITH_DUPLICATE"],
deps = [
":base",
":registry",
":test_main",
"//dragnn/core/test:generic",
"@org_tensorflow//tensorflow/core:test",
],
)
# py graph builder and trainer
tf_gen_op_libs(
op_lib_names = ["parser_ops"],
deps = [":shape_helpers"],
)
tf_gen_op_wrapper_py(
......@@ -819,7 +882,9 @@ py_binary(
deps = [
":graph_builder",
":structured_graph_builder",
":task_spec_py_pb2",
":task_spec_pb2_py",
"@absl_py//absl:app",
"@absl_py//absl/flags",
],
)
......@@ -828,9 +893,11 @@ py_binary(
srcs = ["parser_eval.py"],
deps = [
":graph_builder",
":sentence_py_pb2",
":sentence_pb2_py",
":structured_graph_builder",
":task_spec_py_pb2",
":task_spec_pb2_py",
"@absl_py//absl:app",
"@absl_py//absl/flags",
],
)
......@@ -839,7 +906,18 @@ py_binary(
srcs = ["conll2tree.py"],
deps = [
":graph_builder",
":sentence_py_pb2",
":sentence_pb2_py",
"@absl_py//absl:app",
"@absl_py//absl/flags",
],
)
py_library(
name = "test_flags",
srcs = ["test_flags.py"],
deps = [
"@absl_py//absl/flags",
"@org_tensorflow//tensorflow:tensorflow_py",
],
)
......@@ -851,8 +929,9 @@ py_test(
srcs = ["lexicon_builder_test.py"],
deps = [
":graph_builder",
":sentence_py_pb2",
":task_spec_py_pb2",
":sentence_pb2_py",
":task_spec_pb2_py",
"//syntaxnet:test_flags",
],
)
......@@ -862,8 +941,9 @@ py_test(
srcs = ["text_formats_test.py"],
deps = [
":graph_builder",
":sentence_py_pb2",
":task_spec_py_pb2",
":sentence_pb2_py",
":task_spec_pb2_py",
"//syntaxnet:test_flags",
],
)
......@@ -874,9 +954,10 @@ py_test(
data = [":testdata"],
tags = ["notsan"],
deps = [
":dictionary_py_pb2",
":dictionary_pb2_py",
":graph_builder",
":sparse_py_pb2",
":sparse_pb2_py",
"//syntaxnet:test_flags",
],
)
......@@ -888,6 +969,7 @@ py_test(
tags = ["notsan"],
deps = [
":structured_graph_builder",
"//syntaxnet:test_flags",
],
)
......@@ -901,7 +983,8 @@ py_test(
tags = ["notsan"],
deps = [
":graph_builder",
":sparse_py_pb2",
":sparse_pb2_py",
"//syntaxnet:test_flags",
],
)
......
......@@ -22,6 +22,9 @@ limitations under the License.
#include <unordered_set>
#include <vector>
#include "google/protobuf/util/message_differencer.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/lib/strings/stringprintf.h"
......
......@@ -20,32 +20,25 @@ import os.path
import time
import tensorflow as tf
from tensorflow.python.framework import test_util
from tensorflow.python.platform import googletest
from tensorflow.python.platform import tf_logging as logging
from syntaxnet import structured_graph_builder
from syntaxnet import test_flags
from syntaxnet.ops import gen_parser_ops
FLAGS = tf.app.flags.FLAGS
if not hasattr(FLAGS, 'test_srcdir'):
FLAGS.test_srcdir = ''
if not hasattr(FLAGS, 'test_tmpdir'):
FLAGS.test_tmpdir = tf.test.get_temp_dir()
class ParsingReaderOpsTest(test_util.TensorFlowTestCase):
class ParsingReaderOpsTest(tf.test.TestCase):
def setUp(self):
# Creates a task context with the correct testing paths.
initial_task_context = os.path.join(FLAGS.test_srcdir,
initial_task_context = os.path.join(test_flags.source_root(),
'syntaxnet/'
'testdata/context.pbtxt')
self._task_context = os.path.join(FLAGS.test_tmpdir, 'context.pbtxt')
self._task_context = os.path.join(test_flags.temp_dir(), 'context.pbtxt')
with open(initial_task_context, 'r') as fin:
with open(self._task_context, 'w') as fout:
fout.write(fin.read().replace('SRCDIR', FLAGS.test_srcdir)
.replace('OUTPATH', FLAGS.test_tmpdir))
fout.write(fin.read().replace('SRCDIR', test_flags.source_root())
.replace('OUTPATH', test_flags.temp_dir()))
# Creates necessary term maps.
with self.test_session() as sess:
......@@ -225,4 +218,4 @@ class ParsingReaderOpsTest(test_util.TensorFlowTestCase):
if __name__ == '__main__':
googletest.main()
tf.test.main()
......@@ -73,6 +73,7 @@ limitations under the License.
#include "syntaxnet/registry.h"
#include "syntaxnet/utils.h"
#include "absl/base/macros.h"
// =====================================================================
// Registry for accessing CharProperties by name
......@@ -128,7 +129,7 @@ struct CharPropertyWrapper : RegisterableClass<CharPropertyWrapper> {
static const int k_##name##_unicodes[] = {unicodes}; \
static utils::LazyStaticPtr<CharProperty, const char *, const int *, size_t> \
name##_char_property = {#name, k_##name##_unicodes, \
arraysize(k_##name##_unicodes)}; \
ABSL_ARRAYSIZE(k_##name##_unicodes)}; \
REGISTER_CHAR_PROPERTY(name##_char_property, name); \
DEFINE_IS_X_CHAR_PROPERTY_FUNCTIONS(name##_char_property, name)
......
......@@ -187,7 +187,7 @@ DEFINE_CHAR_PROPERTY(test_punctuation_plus, prop) {
prop->AddCharRange('b', 'b');
prop->AddCharRange('c', 'e');
static const int kUnicodes[] = {'f', RANGE('g', 'i'), 'j'};
prop->AddCharSpec(kUnicodes, arraysize(kUnicodes));
prop->AddCharSpec(kUnicodes, ABSL_ARRAYSIZE(kUnicodes));
prop->AddCharProperty("punctuation");
}
......@@ -223,25 +223,25 @@ const char32 kTestPunctuationPlusExtras[] = {
//
TEST_F(CharPropertiesTest, TestDigit) {
CollectArray(kTestDigit, arraysize(kTestDigit));
CollectArray(kTestDigit, ABSL_ARRAYSIZE(kTestDigit));
ExpectCharPropertyEqualsCollectedSet("test_digit");
}
TEST_F(CharPropertiesTest, TestWavyDash) {
CollectArray(kTestWavyDash, arraysize(kTestWavyDash));
CollectArray(kTestWavyDash, ABSL_ARRAYSIZE(kTestWavyDash));
ExpectCharPropertyEqualsCollectedSet("test_wavy_dash");
}
TEST_F(CharPropertiesTest, TestDigitOrWavyDash) {
CollectArray(kTestDigit, arraysize(kTestDigit));
CollectArray(kTestWavyDash, arraysize(kTestWavyDash));
CollectArray(kTestDigit, ABSL_ARRAYSIZE(kTestDigit));
CollectArray(kTestWavyDash, ABSL_ARRAYSIZE(kTestWavyDash));
ExpectCharPropertyEqualsCollectedSet("test_digit_or_wavy_dash");
}
TEST_F(CharPropertiesTest, TestPunctuationPlus) {
CollectCharProperty("punctuation");
CollectArray(kTestPunctuationPlusExtras,
arraysize(kTestPunctuationPlusExtras));
ABSL_ARRAYSIZE(kTestPunctuationPlusExtras));
ExpectCharPropertyEqualsCollectedSet("test_punctuation_plus");
}
......
......@@ -110,7 +110,9 @@ bool CharShiftTransitionState::IsTokenEnd(int i) const {
}
void CharShiftTransitionSystem::Setup(TaskContext *context) {
left_to_right_ = context->Get("left-to-right", true);
// The version with underscores takes precedence if explicitly set.
left_to_right_ =
context->Get("left_to_right", context->Get("left-to-right", true));
}
bool CharShiftTransitionSystem::IsAllowedAction(
......
......@@ -76,7 +76,7 @@ class CharShiftTransitionTest : public ::testing::Test {
}
void PrepareCharTransition(bool left_to_right) {
context_.SetParameter("left-to-right", left_to_right ? "true" : "false");
context_.SetParameter("left_to_right", left_to_right ? "true" : "false");
transition_system_.reset(ParserTransitionSystem::Create("char-shift-only"));
transition_system_->Setup(&context_);
......@@ -88,7 +88,7 @@ class CharShiftTransitionTest : public ::testing::Test {
}
void PrepareShiftTransition(bool left_to_right) {
context_.SetParameter("left-to-right", left_to_right ? "true" : "false");
context_.SetParameter("left_to_right", left_to_right ? "true" : "false");
transition_system_.reset(ParserTransitionSystem::Create("shift-only"));
transition_system_->Setup(&context_);
state_.reset(new ParserState(
......
......@@ -17,6 +17,8 @@
import collections
import re
from absl import app
from absl import flags
import asciitree
import tensorflow as tf
......@@ -26,7 +28,6 @@ from tensorflow.python.platform import tf_logging as logging
from syntaxnet import sentence_pb2
from syntaxnet.ops import gen_parser_ops
flags = tf.app.flags
FLAGS = flags.FLAGS
flags.DEFINE_string('task_context',
......@@ -88,16 +89,16 @@ def main(unused_argv):
sentence.ParseFromString(d)
tr = asciitree.LeftAligned()
d = to_dict(sentence)
print('Input: %s' % sentence.text)
print('Parse:')
print 'Input: %s' % sentence.text
print 'Parse:'
tr_str = tr(d)
pat = re.compile(r'\s*@\d+$')
for tr_ln in tr_str.splitlines():
print(pat.sub('', tr_ln))
print pat.sub('', tr_ln)
if finished:
break
if __name__ == '__main__':
tf.app.run()
app.run(main)
......@@ -101,6 +101,13 @@ int GenericFeatureFunction::GetIntParameter(const string &name,
tensorflow::strings::safe_strto32);
}
double GenericFeatureFunction::GetFloatParameter(const string &name,
double default_value) const {
const string value = GetParameter(name);
return utils::ParseUsing<double>(value, default_value,
tensorflow::strings::safe_strtod);
}
bool GenericFeatureFunction::GetBoolParameter(const string &name,
bool default_value) const {
const string value = GetParameter(name);
......