ModelZoo / ResNet50_tensorflow / Commits

Commit ea3fa4a3, authored Mar 22, 2017 by Ivan Bogatyy
Update DRAGNN, fix some macOS issues
Parent: b7523ee5

Showing 20 changed files with 1688 additions and 19 deletions (+1688, -19)
syntaxnet/dragnn/python/visualization.py (+15, -0)
syntaxnet/dragnn/python/visualization_test.py (+15, -0)
syntaxnet/dragnn/python/wrapped_units.py (+41, -17)
syntaxnet/dragnn/tools/BUILD (+38, -1)
syntaxnet/dragnn/tools/build_pip_package.py (+15, -0)
syntaxnet/dragnn/tools/evaluator.py (+21, -0)
syntaxnet/dragnn/tools/model_trainer.py (+197, -0)
syntaxnet/dragnn/tools/model_trainer_test.sh (+54, -0)
syntaxnet/dragnn/tools/oss_notebook_launcher.py (+15, -0)
syntaxnet/dragnn/tools/parse-to-conll.py (+15, -0)
syntaxnet/dragnn/tools/parser_trainer.py (+0, -1)
syntaxnet/dragnn/tools/segmenter-evaluator.py (+15, -0)
syntaxnet/dragnn/tools/testdata/biaffine.model/config.txt (+4, -0)
syntaxnet/dragnn/tools/testdata/biaffine.model/hyperparameters.pbtxt (+18, -0)
syntaxnet/dragnn/tools/testdata/biaffine.model/master.pbtxt (+1135, -0)
syntaxnet/dragnn/tools/testdata/biaffine.model/resources/category-map (+7, -0)
syntaxnet/dragnn/tools/testdata/biaffine.model/resources/char-map (+18, -0)
syntaxnet/dragnn/tools/testdata/biaffine.model/resources/char-ngram-map (+46, -0)
syntaxnet/dragnn/tools/testdata/biaffine.model/resources/label-map (+8, -0)
syntaxnet/dragnn/tools/testdata/biaffine.model/resources/lcword-map (+11, -0)
syntaxnet/dragnn/python/visualization.py

# Copyright 2017 Google Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Helper library for visualizations.

TODO(googleuser): Find a more reliable way to serve stuff from IPython
...
syntaxnet/dragnn/python/visualization_test.py

# -*- coding: utf-8 -*-
# (Apache 2.0 license header as above)
"""Tests for dragnn.python.visualization."""

from __future__ import absolute_import
...
syntaxnet/dragnn/python/wrapped_units.py

# (Apache 2.0 license header as above)
"""Network units wrapping TensorFlow's tf.contrib.rnn cells.

Please put all wrapping logic for tf.contrib.rnn in this module; this will help
...
@@ -25,7 +40,7 @@ class BaseLSTMNetwork(dragnn.NetworkUnitInterface):
     logits: Logits associated with component actions.
   """

-  def __init__(self, component):
+  def __init__(self, component, additional_attr_defaults=None):
     """Initializes the LSTM base class.

     Parameters used:
...
@@ -42,15 +57,18 @@ class BaseLSTMNetwork(dragnn.NetworkUnitInterface):
     Args:
       component: parent ComponentBuilderBase object.
+      additional_attr_defaults: Additional attributes for use by derived class.
     """
-    self._attrs = dragnn.get_attrs_with_defaults(
-        component.spec.network_unit.parameters, defaults={
-            'layer_norm': True,
-            'input_dropout_rate': -1.0,
-            'recurrent_dropout_rate': 0.8,
-            'hidden_layer_sizes': '256',
-        })
+    attr_defaults = additional_attr_defaults or {}
+    attr_defaults.update({
+        'layer_norm': True,
+        'input_dropout_rate': -1.0,
+        'recurrent_dropout_rate': 0.8,
+        'hidden_layer_sizes': '256',
+    })
+    self._attrs = dragnn.get_attrs_with_defaults(
+        component.spec.network_unit.parameters, defaults=attr_defaults)

     self._hidden_layer_sizes = map(
         int, self._attrs['hidden_layer_sizes'].split(','))
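The defaults-merging pattern above lets a subclass extend the base attribute set without the base class knowing about it. A plain-Python sketch (no DRAGNN or TensorFlow; `get_attrs_with_defaults` here is a hypothetical stand-in for the real `dragnn.get_attrs_with_defaults`):

```python
def get_attrs_with_defaults(parameters, defaults):
    """Hypothetical stand-in: spec parameters override the merged defaults."""
    attrs = dict(defaults)
    attrs.update(parameters)
    return attrs

BASE_DEFAULTS = {
    'layer_norm': True,
    'input_dropout_rate': -1.0,
    'recurrent_dropout_rate': 0.8,
    'hidden_layer_sizes': '256',
}

def init_attrs(parameters, additional_attr_defaults=None):
    # Derived-class extras are merged first, then the base defaults on top,
    # mirroring BaseLSTMNetwork.__init__ in the hunk above.
    attr_defaults = additional_attr_defaults or {}
    attr_defaults.update(BASE_DEFAULTS)
    return get_attrs_with_defaults(parameters, defaults=attr_defaults)

# A subclass adds its own default alongside the base set:
attrs = init_attrs({'hidden_layer_sizes': '128,128'},
                   additional_attr_defaults={'parallel_iterations': 1})
# The comma-delimited size string is parsed the same way as in the diff:
hidden_sizes = [int(d) for d in attrs['hidden_layer_sizes'].split(',')]
```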
...
@@ -87,8 +105,7 @@ class BaseLSTMNetwork(dragnn.NetworkUnitInterface):
     self._params.append(
         tf.get_variable(
             'weights_softmax', [last_layer_dim, component.num_actions],
-            initializer=tf.random_normal_initializer(stddev=1e-4)))
+            initializer=tf.random_normal_initializer(stddev=1e-4,
+                                                     seed=self._seed)))
     self._params.append(
         tf.get_variable(
             'bias_softmax', [component.num_actions],
...
@@ -116,14 +133,9 @@ class BaseLSTMNetwork(dragnn.NetworkUnitInterface):
     """Appends layers defined by the base class to the |hidden_layers|."""
     last_layer = hidden_layers[-1]
-    # TODO(googleuser): Uncomment the version that uses component.get_variable()
-    # and delete the uses of tf.get_variable().
-    # logits = tf.nn.xw_plus_b(last_layer,
-    #                          self._component.get_variable('weights_softmax'),
-    #                          self._component.get_variable('bias_softmax'))
     logits = tf.nn.xw_plus_b(last_layer,
-                             tf.get_variable('weights_softmax'),
-                             tf.get_variable('bias_softmax'))
+                             self._component.get_variable('weights_softmax'),
+                             self._component.get_variable('bias_softmax'))
     return hidden_layers + [last_layer, logits]

   def _create_cell(self, num_units, during_training):
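For reference, `tf.nn.xw_plus_b(x, w, b)` computes the affine map `x @ W + b`. A tiny pure-Python equivalent over nested lists (illustration only, not the TensorFlow op):

```python
def xw_plus_b(x, weights, biases):
    """Computes logits = x * W + b, where x is a list of rows."""
    columns = list(zip(*weights))  # transpose W to iterate output units
    return [
        [sum(xi * wij for xi, wij in zip(row, col)) + b
         for col, b in zip(columns, biases)]
        for row in x
    ]

# One input row, identity weights, bias of 1 per action:
logits = xw_plus_b([[1.0, 2.0]], [[1.0, 0.0], [0.0, 1.0]], [1.0, 1.0])
```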
...
@@ -321,7 +333,18 @@ class BulkBiLSTMNetwork(BaseLSTMNetwork):
   """

   def __init__(self, component):
-    super(BulkBiLSTMNetwork, self).__init__(component)
+    """Initializes the bulk bi-LSTM.
+
+    Parameters used:
+      parallel_iterations (1): Parallelism of the underlying tf.while_loop().
+        Defaults to 1 thread to encourage deterministic behavior, but can be
+        increased to trade memory for speed.
+
+    Args:
+      component: parent ComponentBuilderBase object.
+    """
+    super(BulkBiLSTMNetwork, self).__init__(
+        component, additional_attr_defaults={'parallel_iterations': 1})

     check.In('lengths', self._linked_feature_dims,
              'Missing required linked feature')
...
@@ -426,6 +449,7 @@ class BulkBiLSTMNetwork(BaseLSTMNetwork):
         initial_states_fw=initial_states_forward,
         initial_states_bw=initial_states_backward,
         sequence_length=lengths_s,
+        parallel_iterations=self._attrs['parallel_iterations'],
         scope=scope)
     return outputs_sxnxd
...
syntaxnet/dragnn/tools/BUILD

 package(default_visibility = ["//visibility:public"])

+filegroup(
+    name = "testdata",
+    srcs = glob(["testdata/**"]),
+)
+
 py_binary(
     name = "evaluator",
     srcs = ["evaluator.py"],
...
@@ -74,6 +79,26 @@ py_binary(
 py_binary(
     name = "segmenter_trainer",
     srcs = ["segmenter_trainer.py"],
+    deps = [
+        "//dragnn/core:dragnn_bulk_ops",
+        "//dragnn/core:dragnn_ops",
+        "//dragnn/protos:spec_py_pb2",
+        "//dragnn/python:evaluation",
+        "//dragnn/python:graph_builder",
+        "//dragnn/python:load_dragnn_cc_impl_py",
+        "//dragnn/python:sentence_io",
+        "//dragnn/python:spec_builder",
+        "//dragnn/python:trainer_lib",
+        "//syntaxnet:load_parser_ops_py",
+        "//syntaxnet:parser_ops",
+        "@org_tensorflow//tensorflow:tensorflow_py",
+        "@org_tensorflow//tensorflow/core:protos_all_py",
+    ],
+)
+
+py_binary(
+    name = "model_trainer",
+    srcs = ["model_trainer.py"],
     deps = [
         "//dragnn/core:dragnn_bulk_ops",
         "//dragnn/core:dragnn_ops",
...
@@ -81,7 +106,6 @@ py_binary(
         "//dragnn/python:dragnn_ops",
         "//dragnn/python:evaluation",
         "//dragnn/python:graph_builder",
-        "//dragnn/python:lexicon",
         "//dragnn/python:load_dragnn_cc_impl_py",
         "//dragnn/python:sentence_io",
         "//dragnn/python:spec_builder",
...
@@ -90,11 +114,24 @@ py_binary(
         "//syntaxnet:parser_ops",
         "//syntaxnet:sentence_py_pb2",
         "//syntaxnet:task_spec_py_pb2",
+        "//syntaxnet/util:check",
         "@org_tensorflow//tensorflow:tensorflow_py",
         "@org_tensorflow//tensorflow/core:protos_all_py",
     ],
 )

+sh_test(
+    name = "model_trainer_test",
+    size = "medium",
+    srcs = ["model_trainer_test.sh"],
+    data = [
+        ":model_trainer",
+        ":testdata",
+    ],
+    deps = [],
+)
+
 # This is meant to be run inside the Docker image. In the OSS directory, run,
 #
 #   ./build_devel.sh bazel run //dragnn/python:oss_notebook_launcher
...
syntaxnet/dragnn/tools/build_pip_package.py

# (Apache 2.0 license header as above)
"""Builds a pip package suitable for redistribution.

Adapted from tensorflow/tools/pip_package/build_pip_package.sh. This might have
...
syntaxnet/dragnn/tools/evaluator.py

# (Apache 2.0 license header as above)
r"""Runs a DRAGNN model on a given set of CoNLL-formatted sentences.

Sample invocation:
...
...
@@ -51,6 +66,9 @@ flags.DEFINE_integer('threads', 10, 'Number of threads used for intra- and '
 flags.DEFINE_string('timeline_output_file', '', 'Path to save timeline to. '
                     'If specified, the final iteration of the evaluation loop '
                     'will capture and save a TensorFlow timeline.')
+flags.DEFINE_string('log_file', '', 'File path to write parser eval results.')
+flags.DEFINE_string('language_name', '_', 'Name of language being parsed, '
+                    'for logging.')


 def main(unused_argv):
...
@@ -134,6 +152,9 @@ def main(unused_argv):
   tf.logging.info('Processed %d documents in %.2f seconds.',
                   len(input_corpus), time.time() - start_time)
   pos, uas, las = evaluation.calculate_parse_metrics(input_corpus, processed)
+  if FLAGS.log_file:
+    with gfile.GFile(FLAGS.log_file, 'w') as f:
+      f.write('%s\t%f\t%f\t%f\n' % (FLAGS.language_name, pos, uas, las))

   if FLAGS.output_file:
     with gfile.GFile(FLAGS.output_file, 'w') as f:
...
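The new `--log_file` output is a single tab-separated line per run: language name, then POS, UAS, and LAS as floats. A quick sketch of the format (the metric values here are made up for illustration):

```python
language_name = 'en'                # value of --language_name
pos, uas, las = 97.1, 92.3, 90.8    # hypothetical metric values

# Same format string as the diff above:
line = '%s\t%f\t%f\t%f\n' % (language_name, pos, uas, las)
fields = line.rstrip('\n').split('\t')
```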
syntaxnet/dragnn/tools/model_trainer.py (new file, mode 100755)

# (Apache 2.0 license header as above)
"""Trainer for generic DRAGNN models.

This trainer uses a "model directory" for both input and output. When invoked,
the model directory should contain the following inputs:

  <model_dir>/config.txt: A stringified dict that defines high-level
    configuration parameters. Unset parameters default to False.
  <model_dir>/master.pbtxt: A text-format MasterSpec proto that defines
    the DRAGNN network to train.
  <model_dir>/hyperparameters.pbtxt: A text-format GridPoint proto that
    defines training hyper-parameters.
  <model_dir>/targets.pbtxt: (Optional) A text-format TrainingGridSpec whose
    "target" field defines the training targets. If missing, then default
    training targets are used instead.

On success, the model directory will contain the following outputs:

  <model_dir>/checkpoints/best: The best checkpoint seen during training, as
    measured by accuracy on the eval corpus.
  <model_dir>/tensorboard: TensorBoard log directory.

Outside of the files and subdirectories named above, the model directory should
contain any other necessary files (e.g., pretrained embeddings). See the model
builders in dragnn/examples.
"""
import ast
import collections
import os
import os.path

import tensorflow as tf

from google.protobuf import text_format
from dragnn.protos import spec_pb2
from dragnn.python import evaluation
from dragnn.python import graph_builder
from dragnn.python import sentence_io
from dragnn.python import spec_builder
from dragnn.python import trainer_lib
from syntaxnet.ops import gen_parser_ops
from syntaxnet.util import check

import dragnn.python.load_dragnn_cc_impl
import syntaxnet.load_parser_ops

flags = tf.app.flags
FLAGS = flags.FLAGS

flags.DEFINE_string('tf_master', '',
                    'TensorFlow execution engine to connect to.')
flags.DEFINE_string('model_dir', None, 'Path to a prepared model directory.')
flags.DEFINE_string(
    'pretrain_steps', None,
    'Comma-delimited list of pre-training steps per training target.')
flags.DEFINE_string(
    'pretrain_epochs', None,
    'Comma-delimited list of pre-training epochs per training target.')
flags.DEFINE_string(
    'train_steps', None,
    'Comma-delimited list of training steps per training target.')
flags.DEFINE_string(
    'train_epochs', None,
    'Comma-delimited list of training epochs per training target.')
flags.DEFINE_integer('batch_size', 4, 'Batch size.')
flags.DEFINE_integer('report_every', 200,
                     'Report cost and training accuracy every this many steps.')
def _read_text_proto(path, proto_type):
  """Reads a text-format instance of |proto_type| from the |path|."""
  proto = proto_type()
  with tf.gfile.FastGFile(path) as proto_file:
    text_format.Parse(proto_file.read(), proto)
  return proto


def _convert_to_char_corpus(corpus):
  """Converts the word-based |corpus| into a char-based corpus."""
  with tf.Session(graph=tf.Graph()) as tmp_session:
    conversion_op = gen_parser_ops.segmenter_training_data_constructor(corpus)
    return tmp_session.run(conversion_op)


def _get_steps(steps_flag, epochs_flag, corpus_length):
  """Converts the |steps_flag| or |epochs_flag| into a list of step counts."""
  if steps_flag:
    return map(int, steps_flag.split(','))
  return [corpus_length * int(epochs) for epochs in epochs_flag.split(',')]
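The `_get_steps` helper above, restated as a standalone Python 3 sketch (a list comprehension in place of Python 2's `map`; behavior is otherwise the same): an explicit per-target step list wins, otherwise each epoch count is multiplied by the corpus size.

```python
def get_steps(steps_flag, epochs_flag, corpus_length):
    """Converts either flag into a per-target list of step counts."""
    if steps_flag:
        return [int(steps) for steps in steps_flag.split(',')]
    # One epoch = one pass over the corpus, so steps = epochs * corpus size.
    return [corpus_length * int(epochs) for epochs in epochs_flag.split(',')]

explicit = get_steps('100,200', None, 1000)  # --train_steps given
derived = get_steps('', '1,5', 1000)         # --train_epochs given
```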
def main(unused_argv):
  tf.logging.set_verbosity(tf.logging.INFO)

  check.NotNone(FLAGS.model_dir, '--model_dir is required')
  check.Ne(FLAGS.pretrain_steps is None, FLAGS.pretrain_epochs is None,
           'Exactly one of --pretrain_steps or --pretrain_epochs is required')
  check.Ne(FLAGS.train_steps is None, FLAGS.train_epochs is None,
           'Exactly one of --train_steps or --train_epochs is required')

  config_path = os.path.join(FLAGS.model_dir, 'config.txt')
  master_path = os.path.join(FLAGS.model_dir, 'master.pbtxt')
  hyperparameters_path = os.path.join(FLAGS.model_dir, 'hyperparameters.pbtxt')
  targets_path = os.path.join(FLAGS.model_dir, 'targets.pbtxt')
  checkpoint_path = os.path.join(FLAGS.model_dir, 'checkpoints/best')
  tensorboard_dir = os.path.join(FLAGS.model_dir, 'tensorboard')

  with tf.gfile.FastGFile(config_path) as config_file:
    config = collections.defaultdict(bool,
                                     ast.literal_eval(config_file.read()))
  train_corpus_path = config['train_corpus_path']
  tune_corpus_path = config['tune_corpus_path']
  projectivize_train_corpus = config['projectivize_train_corpus']

  master = _read_text_proto(master_path, spec_pb2.MasterSpec)
  hyperparameters = _read_text_proto(hyperparameters_path, spec_pb2.GridPoint)
  targets = spec_builder.default_targets_from_spec(master)
  if tf.gfile.Exists(targets_path):
    targets = _read_text_proto(targets_path, spec_pb2.TrainingGridSpec).target

  # Build the TensorFlow graph.
  graph = tf.Graph()
  with graph.as_default():
    tf.set_random_seed(hyperparameters.seed)
    builder = graph_builder.MasterBuilder(master, hyperparameters)
    trainers = [
        builder.add_training_from_config(target) for target in targets
    ]
    annotator = builder.add_annotation()
    builder.add_saver()

  # Read in serialized protos from training data.
  train_corpus = sentence_io.ConllSentenceReader(
      train_corpus_path, projectivize=projectivize_train_corpus).corpus()
  tune_corpus = sentence_io.ConllSentenceReader(
      tune_corpus_path, projectivize=False).corpus()
  gold_tune_corpus = tune_corpus

  # Convert to char-based corpora, if requested.
  if config['convert_to_char_corpora']:
    # NB: Do not convert the |gold_tune_corpus|, which should remain word-based
    # for segmentation evaluation purposes.
    train_corpus = _convert_to_char_corpus(train_corpus)
    tune_corpus = _convert_to_char_corpus(tune_corpus)

  pretrain_steps = _get_steps(FLAGS.pretrain_steps, FLAGS.pretrain_epochs,
                              len(train_corpus))
  train_steps = _get_steps(FLAGS.train_steps, FLAGS.train_epochs,
                           len(train_corpus))
  check.Eq(len(targets), len(pretrain_steps),
           'Length mismatch between training targets and --pretrain_steps')
  check.Eq(len(targets), len(train_steps),
           'Length mismatch between training targets and --train_steps')

  # Ready to train!
  tf.logging.info('Training on %d sentences.', len(train_corpus))
  tf.logging.info('Tuning on %d sentences.', len(tune_corpus))

  tf.logging.info('Creating TensorFlow checkpoint dir...')
  summary_writer = trainer_lib.get_summary_writer(tensorboard_dir)

  checkpoint_dir = os.path.dirname(checkpoint_path)
  if tf.gfile.IsDirectory(checkpoint_dir):
    tf.gfile.DeleteRecursively(checkpoint_dir)
  elif tf.gfile.Exists(checkpoint_dir):
    tf.gfile.Remove(checkpoint_dir)
  tf.gfile.MakeDirs(checkpoint_dir)

  with tf.Session(FLAGS.tf_master, graph=graph) as sess:
    # Make sure to re-initialize all underlying state.
    sess.run(tf.global_variables_initializer())
    trainer_lib.run_training(sess, trainers, annotator,
                             evaluation.parser_summaries, pretrain_steps,
                             train_steps, train_corpus, tune_corpus,
                             gold_tune_corpus, FLAGS.batch_size,
                             summary_writer, FLAGS.report_every,
                             builder.saver, checkpoint_path)

  tf.logging.info('Best checkpoint written to:\n%s', checkpoint_path)


if __name__ == '__main__':
  tf.app.run()
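The `check.Ne(a is None, b is None, …)` calls in `main` enforce that exactly one flag of a steps/epochs pair is set: the two `is None` booleans must differ. A minimal stand-in (`check_ne` is hypothetical; `syntaxnet.util.check` is the real module):

```python
def check_ne(left, right, message):
    """Hypothetical stand-in for syntaxnet.util.check.Ne()."""
    if left == right:
        raise ValueError(message)

def validate(train_steps, train_epochs):
    # Exactly one of the pair must be set (XOR on the None-ness):
    check_ne(train_steps is None, train_epochs is None,
             'Exactly one of --train_steps or --train_epochs is required')

validate('100,200', None)   # OK: steps set, epochs unset
try:
    validate(None, None)    # Neither set: rejected
    both_unset_ok = True
except ValueError:
    both_unset_ok = False
```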
syntaxnet/dragnn/tools/model_trainer_test.sh (new file, mode 100755)

#!/bin/bash
# (Apache 2.0 license header as above)
# ==============================================================================
# This test runs the model trainer on a snapshotted model directory. This is a
# "don't crash" test, so it does not evaluate the trained model.

set -eu

readonly DRAGNN_DIR="${TEST_SRCDIR}/${TEST_WORKSPACE}/dragnn"
readonly MODEL_TRAINER="${DRAGNN_DIR}/tools/model_trainer"
readonly MODEL_DIR="${DRAGNN_DIR}/tools/testdata/biaffine.model"
readonly CORPUS="${DRAGNN_DIR}/tools/testdata/small.conll"
readonly TMP_DIR="/tmp/model_trainer_test.$$"
readonly TMP_MODEL_DIR="${TMP_DIR}/biaffine.model"

rm -rf "${TMP_DIR}"
mkdir -p "${TMP_DIR}"

# Copy all testdata files to a temp dir, so they can be modified (see below).
cp "${CORPUS}" "${TMP_DIR}"
mkdir -p "${TMP_MODEL_DIR}"
for name in hyperparameters.pbtxt targets.pbtxt resources; do
  cp -r "${MODEL_DIR}/${name}" "${TMP_MODEL_DIR}/${name}"
done

# Replace "TESTDATA" with the temp dir path in config files that contain paths.
for name in config.txt master.pbtxt; do
  sed "s=TESTDATA=${TMP_DIR}=" "${MODEL_DIR}/${name}" \
    > "${TMP_MODEL_DIR}/${name}"
done

"${MODEL_TRAINER}" \
  --model_dir="${TMP_MODEL_DIR}" \
  --pretrain_steps='1' \
  --train_epochs='10' \
  --alsologtostderr

echo "PASS"
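The `sed "s=TESTDATA=${TMP_DIR}="` step above rewrites the `TESTDATA` placeholder in the config files to the temp-dir path; `=` is just an alternate delimiter so the path's slashes need no escaping. The same substitution in Python terms (the file content and temp path here are illustrative):

```python
tmp_dir = '/tmp/model_trainer_test.12345'  # stand-in for ${TMP_DIR}
config_text = (
    "{\n"
    "  'train_corpus_path': 'TESTDATA/small.conll',\n"
    "  'tune_corpus_path': 'TESTDATA/small.conll',\n"
    "}\n"
)
patched = config_text.replace('TESTDATA', tmp_dir)
```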
syntaxnet/dragnn/tools/oss_notebook_launcher.py

# (Apache 2.0 license header as above)
"""Mini OSS launcher so we can build a py_binary for OSS."""

from __future__ import absolute_import
...
syntaxnet/dragnn/tools/parse-to-conll.py

# (Apache 2.0 license header as above)
r"""Runs both a segmentation and a parsing model on a CoNLL dataset.
"""
...
syntaxnet/dragnn/tools/parser_trainer.py

...
@@ -60,7 +60,6 @@ flags.DEFINE_string('dev_corpus_path', '', 'Path to development set data.')
 flags.DEFINE_bool('compute_lexicon', False, '')
 flags.DEFINE_bool('projectivize_training_set', True, '')
-flags.DEFINE_integer('num_epochs', 10, 'Number of epochs to train for.')
 flags.DEFINE_integer('batch_size', 4, 'Batch size.')
 flags.DEFINE_integer('report_every', 200,
                      'Report cost and training accuracy every this many steps.')
...
syntaxnet/dragnn/tools/segmenter-evaluator.py

# (Apache 2.0 license header as above)
r"""Runs a DRAGNN model on a given set of CoNLL-formatted sentences.

Sample invocation:
...
syntaxnet/dragnn/tools/testdata/biaffine.model/config.txt (new file, mode 100644)

{
'train_corpus_path': 'TESTDATA/small.conll',
'tune_corpus_path': 'TESTDATA/small.conll',
}
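model_trainer.py reads this file with `ast.literal_eval` and wraps the result in `collections.defaultdict(bool)`, which is how unset parameters default to False. A quick sketch of that parsing step:

```python
import ast
import collections

config_text = """{
  'train_corpus_path': 'TESTDATA/small.conll',
  'tune_corpus_path': 'TESTDATA/small.conll',
}"""

config = collections.defaultdict(bool, ast.literal_eval(config_text))

train_corpus_path = config['train_corpus_path']
# Keys absent from config.txt quietly evaluate to False:
projectivize = config['projectivize_train_corpus']
```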
syntaxnet/dragnn/tools/testdata/biaffine.model/hyperparameters.pbtxt (new file, mode 100644)

learning_method: "adam"
adam_beta1: 0.9
adam_beta2: 0.9
adam_eps: 1e-12
learning_rate: 0.002
decay_base: 0.75
decay_staircase: false
decay_steps: 2500
dropout_rate: 0.67
recurrent_dropout_rate: 0.75
gradient_clip_norm: 15
l2_regularization_coefficient: 0
use_moving_average: false
seed: 1
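Assuming the decay fields follow TensorFlow's usual `tf.train.exponential_decay` semantics (an assumption; the exact schedule is defined by the trainer, not by this file), the learning rate decays as `learning_rate * decay_base ** (step / decay_steps)`, with `decay_staircase` flooring the exponent to whole intervals:

```python
def decayed_learning_rate(base_rate, step, decay_steps, decay_base, staircase):
    """Sketch of exponential decay under the hyperparameters above."""
    exponent = step / float(decay_steps)
    if staircase:
        exponent = step // decay_steps  # whole decay intervals only
    return base_rate * (decay_base ** exponent)

# With learning_rate=0.002, decay_base=0.75, decay_steps=2500 (staircase off),
# the rate after 2500 steps is 0.002 * 0.75 = 0.0015:
rate_at_2500 = decayed_learning_rate(0.002, 2500, 2500, 0.75, False)
```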
syntaxnet/dragnn/tools/testdata/biaffine.model/master.pbtxt (new file, mode 100644)

(This 1,135-line diff is collapsed in the page view; its content is not shown.)
syntaxnet/dragnn/tools/testdata/biaffine.model/resources/category-map (new file, mode 100644)

6
VERB 6
NOUN 5
PRON 5
PUNCT 5
DET 2
CONJ 1
syntaxnet/dragnn/tools/testdata/biaffine.model/resources/char-map (new file, mode 100644)

17
o 10
e 8
b 6
s 6
. 5
h 5
l 5
y 5
k 4
T 3
a 3
n 3
u 3
I 2
v 2
c 1
d 1
syntaxnet/dragnn/tools/testdata/biaffine.model/resources/char-ngram-map (new file, mode 100644)

45
o 8
^ b 6
^ . $ 5
e 5
y $ 5
^ bo 4
k 4
ks $ 4
ok 4
oo 4
s $ 4
^ T 3
^ Th 3
e $ 3
ey $ 3
h 3
he 3
l 3
u 3
^ I $ 2
^ bu 2
^ h 2
^ ha 2
^ n 2
^ no $ 2
^ s 2
^ se 2
a 2
av 2
el 2
l $ 2
ll $ 2
o $ 2
uy $ 2
v 2
ve $ 2
^ a 1
^ an 1
^ c 1
^ cl 1
d $ 1
lu 1
n 1
nd $ 1
ue $ 1
syntaxnet/dragnn/tools/testdata/biaffine.model/resources/label-map (new file, mode 100644)

7
ROOT 5
nsubj 5
obj 5
punct 5
det 2
cc 1
conj 1
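These resource files appear to share one format (an inference from the data above, not documented in this commit): an entry count on the first line, then one `term count` pair per line in non-increasing count order. A parser sketch, fed the label-map contents above:

```python
def parse_frequency_map(text):
    """Parses 'N\nterm count\n...' resource files into a dict."""
    lines = text.strip().split('\n')
    expected = int(lines[0])
    entries = {}
    for line in lines[1:]:
        # rsplit keeps multi-token terms (e.g. '^ bo' in char-ngram-map) whole.
        term, count = line.rsplit(' ', 1)
        entries[term] = int(count)
    assert len(entries) == expected, 'entry count mismatch'
    return entries

label_map = parse_frequency_map(
    '7\nROOT 5\nnsubj 5\nobj 5\npunct 5\ndet 2\ncc 1\nconj 1')
```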
syntaxnet/dragnn/tools/testdata/biaffine.model/resources/lcword-map (new file, mode 100644)

10
. 5
books 4
they 3
buy 2
have 2
i 2
no 2
sell 2
and 1
clue 1