Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
ModelZoo
ResNet50_tensorflow
Commits
277f99c7
Commit
277f99c7
authored
Mar 23, 2017
by
Ivan Bogatyy
Committed by
GitHub
Mar 23, 2017
Browse files
Merge pull request #1243 from bogatyy/master
Add license headers, fix some macOS issues
parents
f7cea8d0
ea3fa4a3
Changes
115
Show whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
403 additions
and
78 deletions
+403
-78
syntaxnet/dragnn/python/bulk_component.py
syntaxnet/dragnn/python/bulk_component.py
+19
-2
syntaxnet/dragnn/python/bulk_component_test.py
syntaxnet/dragnn/python/bulk_component_test.py
+15
-0
syntaxnet/dragnn/python/component.py
syntaxnet/dragnn/python/component.py
+43
-3
syntaxnet/dragnn/python/composite_optimizer.py
syntaxnet/dragnn/python/composite_optimizer.py
+15
-0
syntaxnet/dragnn/python/composite_optimizer_test.py
syntaxnet/dragnn/python/composite_optimizer_test.py
+16
-2
syntaxnet/dragnn/python/digraph_ops.py
syntaxnet/dragnn/python/digraph_ops.py
+15
-0
syntaxnet/dragnn/python/digraph_ops_test.py
syntaxnet/dragnn/python/digraph_ops_test.py
+15
-0
syntaxnet/dragnn/python/dragnn_ops.py
syntaxnet/dragnn/python/dragnn_ops.py
+15
-0
syntaxnet/dragnn/python/graph_builder.py
syntaxnet/dragnn/python/graph_builder.py
+49
-27
syntaxnet/dragnn/python/graph_builder_test.py
syntaxnet/dragnn/python/graph_builder_test.py
+32
-0
syntaxnet/dragnn/python/network_units.py
syntaxnet/dragnn/python/network_units.py
+41
-40
syntaxnet/dragnn/python/network_units_test.py
syntaxnet/dragnn/python/network_units_test.py
+15
-0
syntaxnet/dragnn/python/render_parse_tree_graphviz.py
syntaxnet/dragnn/python/render_parse_tree_graphviz.py
+15
-0
syntaxnet/dragnn/python/render_parse_tree_graphviz_test.py
syntaxnet/dragnn/python/render_parse_tree_graphviz_test.py
+15
-0
syntaxnet/dragnn/python/render_spec_with_graphviz.py
syntaxnet/dragnn/python/render_spec_with_graphviz.py
+15
-0
syntaxnet/dragnn/python/render_spec_with_graphviz_test.py
syntaxnet/dragnn/python/render_spec_with_graphviz_test.py
+15
-0
syntaxnet/dragnn/python/sentence_io.py
syntaxnet/dragnn/python/sentence_io.py
+15
-0
syntaxnet/dragnn/python/sentence_io_test.py
syntaxnet/dragnn/python/sentence_io_test.py
+15
-0
syntaxnet/dragnn/python/spec_builder.py
syntaxnet/dragnn/python/spec_builder.py
+6
-2
syntaxnet/dragnn/python/trainer_lib.py
syntaxnet/dragnn/python/trainer_lib.py
+17
-2
No files found.
syntaxnet/dragnn/python/bulk_component.py
View file @
277f99c7
# Copyright 2017 Google Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Component builders for non-recurrent networks in DRAGNN."""
"""Component builders for non-recurrent networks in DRAGNN."""
...
@@ -249,7 +264,8 @@ class BulkFeatureExtractorComponentBuilder(component.ComponentBuilderBase):
...
@@ -249,7 +264,8 @@ class BulkFeatureExtractorComponentBuilder(component.ComponentBuilderBase):
update_network_states
(
self
,
tensors
,
network_states
,
stride
)
update_network_states
(
self
,
tensors
,
network_states
,
stride
)
cost
=
self
.
add_regularizer
(
tf
.
constant
(
0.
))
cost
=
self
.
add_regularizer
(
tf
.
constant
(
0.
))
return
state
.
handle
,
cost
,
0
,
0
correct
,
total
=
tf
.
constant
(
0
),
tf
.
constant
(
0
)
return
state
.
handle
,
cost
,
correct
,
total
def
build_greedy_inference
(
self
,
state
,
network_states
,
def
build_greedy_inference
(
self
,
state
,
network_states
,
during_training
=
False
):
during_training
=
False
):
...
@@ -327,7 +343,8 @@ class BulkFeatureIdExtractorComponentBuilder(component.ComponentBuilderBase):
...
@@ -327,7 +343,8 @@ class BulkFeatureIdExtractorComponentBuilder(component.ComponentBuilderBase):
"""See base class."""
"""See base class."""
state
.
handle
=
self
.
_extract_feature_ids
(
state
,
network_states
,
True
)
state
.
handle
=
self
.
_extract_feature_ids
(
state
,
network_states
,
True
)
cost
=
self
.
add_regularizer
(
tf
.
constant
(
0.
))
cost
=
self
.
add_regularizer
(
tf
.
constant
(
0.
))
return
state
.
handle
,
cost
,
0
,
0
correct
,
total
=
tf
.
constant
(
0
),
tf
.
constant
(
0
)
return
state
.
handle
,
cost
,
correct
,
total
def
build_greedy_inference
(
self
,
state
,
network_states
,
def
build_greedy_inference
(
self
,
state
,
network_states
,
during_training
=
False
):
during_training
=
False
):
...
...
syntaxnet/dragnn/python/bulk_component_test.py
View file @
277f99c7
# Copyright 2017 Google Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for bulk_component.
"""Tests for bulk_component.
Verifies that:
Verifies that:
...
...
syntaxnet/dragnn/python/component.py
View file @
277f99c7
# Copyright 2017 Google Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Builds a DRAGNN graph for local training."""
"""Builds a DRAGNN graph for local training."""
from
abc
import
ABCMeta
from
abc
import
ABCMeta
...
@@ -147,6 +162,32 @@ class ComponentBuilderBase(object):
...
@@ -147,6 +162,32 @@ class ComponentBuilderBase(object):
"""
"""
pass
pass
def
build_structured_training
(
self
,
state
,
network_states
):
"""Builds a beam search based training loop for this component.
The default implementation builds a dummy graph and raises a
TensorFlow runtime exception to indicate that structured training
is not implemented.
Args:
state: MasterState from the 'AdvanceMaster' op that advances the
underlying master to this component.
network_states: dictionary of component NetworkState objects.
Returns:
(handle, cost, correct, total) -- These are TF ops corresponding
to the final handle after unrolling, the total cost, and the
total number of actions. Since the number of correctly predicted
actions is not applicable in the structured training setting, a
dummy value should returned.
"""
del
network_states
# Unused.
with
tf
.
control_dependencies
([
tf
.
Assert
(
False
,
[
'Not implemented.'
])]):
handle
=
tf
.
identity
(
state
.
handle
)
cost
=
tf
.
constant
(
0.
)
correct
,
total
=
tf
.
constant
(
0
),
tf
.
constant
(
0
)
return
handle
,
cost
,
correct
,
total
@
abstractmethod
@
abstractmethod
def
build_greedy_inference
(
self
,
state
,
network_states
,
def
build_greedy_inference
(
self
,
state
,
network_states
,
during_training
=
False
):
during_training
=
False
):
...
@@ -349,14 +390,13 @@ class DynamicComponentBuilder(ComponentBuilderBase):
...
@@ -349,14 +390,13 @@ class DynamicComponentBuilder(ComponentBuilderBase):
correctly predicted actions, and the total number of actions.
correctly predicted actions, and the total number of actions.
"""
"""
logging
.
info
(
'Building component: %s'
,
self
.
spec
.
name
)
logging
.
info
(
'Building component: %s'
,
self
.
spec
.
name
)
with
tf
.
control_dependencies
([
tf
.
assert_equal
(
self
.
training_beam_size
,
1
)]):
stride
=
state
.
current_batch_size
*
self
.
training_beam_size
stride
=
state
.
current_batch_size
*
self
.
training_beam_size
cost
=
tf
.
constant
(
0.
)
cost
=
tf
.
constant
(
0.
)
correct
=
tf
.
constant
(
0
)
correct
=
tf
.
constant
(
0
)
total
=
tf
.
constant
(
0
)
total
=
tf
.
constant
(
0
)
# Create the TensorArray's to store activations for downstream/recurrent
# connections.
def
cond
(
handle
,
*
_
):
def
cond
(
handle
,
*
_
):
all_final
=
dragnn_ops
.
emit_all_final
(
handle
,
component
=
self
.
name
)
all_final
=
dragnn_ops
.
emit_all_final
(
handle
,
component
=
self
.
name
)
return
tf
.
logical_not
(
tf
.
reduce_all
(
all_final
))
return
tf
.
logical_not
(
tf
.
reduce_all
(
all_final
))
...
...
syntaxnet/dragnn/python/composite_optimizer.py
View file @
277f99c7
# Copyright 2017 Google Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""An optimizer that switches between several methods."""
"""An optimizer that switches between several methods."""
import
tensorflow
as
tf
import
tensorflow
as
tf
...
...
syntaxnet/dragnn/python/composite_optimizer_test.py
View file @
277f99c7
"""Tests for CompositeOptimizer.
# Copyright 2017 Google Inc. All Rights Reserved.
"""
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for CompositeOptimizer."""
import
numpy
as
np
import
numpy
as
np
...
...
syntaxnet/dragnn/python/digraph_ops.py
View file @
277f99c7
# Copyright 2017 Google Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""TensorFlow ops for directed graphs."""
"""TensorFlow ops for directed graphs."""
import
tensorflow
as
tf
import
tensorflow
as
tf
...
...
syntaxnet/dragnn/python/digraph_ops_test.py
View file @
277f99c7
# Copyright 2017 Google Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for digraph ops."""
"""Tests for digraph ops."""
import
tensorflow
as
tf
import
tensorflow
as
tf
...
...
syntaxnet/dragnn/python/dragnn_ops.py
View file @
277f99c7
# Copyright 2017 Google Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Groups the DRAGNN TensorFlow ops in one module."""
"""Groups the DRAGNN TensorFlow ops in one module."""
...
...
syntaxnet/dragnn/python/graph_builder.py
View file @
277f99c7
# Copyright 2017 Google Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Builds a DRAGNN graph for local training."""
"""Builds a DRAGNN graph for local training."""
...
@@ -65,6 +80,13 @@ def _create_optimizer(hyperparams, learning_rate_var, step_var=None):
...
@@ -65,6 +80,13 @@ def _create_optimizer(hyperparams, learning_rate_var, step_var=None):
beta2
=
hyperparams
.
adam_beta2
,
beta2
=
hyperparams
.
adam_beta2
,
epsilon
=
hyperparams
.
adam_eps
,
epsilon
=
hyperparams
.
adam_eps
,
use_locking
=
True
)
use_locking
=
True
)
elif
hyperparams
.
learning_method
==
'lazyadam'
:
return
tf
.
contrib
.
opt
.
LazyAdamOptimizer
(
learning_rate_var
,
beta1
=
hyperparams
.
adam_beta1
,
beta2
=
hyperparams
.
adam_beta2
,
epsilon
=
hyperparams
.
adam_eps
,
use_locking
=
True
)
elif
hyperparams
.
learning_method
==
'momentum'
:
elif
hyperparams
.
learning_method
==
'momentum'
:
return
tf
.
train
.
MomentumOptimizer
(
return
tf
.
train
.
MomentumOptimizer
(
learning_rate_var
,
hyperparams
.
momentum
,
use_locking
=
True
)
learning_rate_var
,
hyperparams
.
momentum
,
use_locking
=
True
)
...
@@ -138,6 +160,10 @@ class MasterBuilder(object):
...
@@ -138,6 +160,10 @@ class MasterBuilder(object):
if
hyperparam_config
is
None
else
hyperparam_config
)
if
hyperparam_config
is
None
else
hyperparam_config
)
self
.
pool_scope
=
pool_scope
self
.
pool_scope
=
pool_scope
# Set the graph-level random seed before creating the Components so the ops
# they create will use this seed.
tf
.
set_random_seed
(
hyperparam_config
.
seed
)
# Construct all utility class and variables for each Component.
# Construct all utility class and variables for each Component.
self
.
components
=
[]
self
.
components
=
[]
self
.
lookup_component
=
{}
self
.
lookup_component
=
{}
...
@@ -318,15 +344,18 @@ class MasterBuilder(object):
...
@@ -318,15 +344,18 @@ class MasterBuilder(object):
dragnn_ops
.
batch_size
(
dragnn_ops
.
batch_size
(
handle
,
component
=
comp
.
name
))
handle
,
component
=
comp
.
name
))
with
tf
.
control_dependencies
([
handle
,
cost
]):
with
tf
.
control_dependencies
([
handle
,
cost
]):
component_cost
=
tf
.
constant
(
0.
)
args
=
(
master_state
,
network_states
)
component_correct
=
tf
.
constant
(
0
)
component_total
=
tf
.
constant
(
0
)
if
unroll_using_oracle
[
component_index
]:
if
unroll_using_oracle
[
component_index
]:
handle
,
component_cost
,
component_correct
,
component_total
=
(
comp
.
build_greedy_training
(
master_state
,
network_states
))
handle
,
component_cost
,
component_correct
,
component_total
=
(
tf
.
cond
(
comp
.
training_beam_size
>
1
,
lambda
:
comp
.
build_structured_training
(
*
args
),
lambda
:
comp
.
build_greedy_training
(
*
args
)))
else
:
else
:
handle
=
comp
.
build_greedy_inference
(
handle
=
comp
.
build_greedy_inference
(
*
args
,
during_training
=
True
)
master_state
,
network_states
,
during_training
=
True
)
component_cost
=
tf
.
constant
(
0.
)
component_correct
,
component_total
=
tf
.
constant
(
0
),
tf
.
constant
(
0
)
weighted_component_cost
=
tf
.
multiply
(
weighted_component_cost
=
tf
.
multiply
(
component_cost
,
component_cost
,
...
@@ -497,30 +526,23 @@ class MasterBuilder(object):
...
@@ -497,30 +526,23 @@ class MasterBuilder(object):
with
tf
.
name_scope
(
scope_id
):
with
tf
.
name_scope
(
scope_id
):
# Construct training targets. Disable tracing during training.
# Construct training targets. Disable tracing during training.
handle
,
input_batch
=
self
.
_get_session_with_reader
(
trace_only
)
handle
,
input_batch
=
self
.
_get_session_with_reader
(
trace_only
)
if
trace_only
:
# Build a training graph that doesn't have any side effects.
# If `trace_only` is True, the training graph shouldn't have any
# side effects. Otherwise, the standard training scenario should
# generate gradients and update counters.
handle
,
outputs
=
self
.
build_training
(
handle
,
outputs
=
self
.
build_training
(
handle
,
handle
,
compute_gradients
=
False
,
compute_gradients
=
not
trace_only
,
advance_counters
=
False
,
advance_counters
=
not
trace_only
,
component_weights
=
target_config
.
component_weights
,
component_weights
=
target_config
.
component_weights
,
unroll_using_oracle
=
target_config
.
unroll_using_oracle
,
unroll_using_oracle
=
target_config
.
unroll_using_oracle
,
max_index
=
target_config
.
max_index
,
max_index
=
target_config
.
max_index
,
**
kwargs
)
**
kwargs
)
if
trace_only
:
outputs
[
'traces'
]
=
dragnn_ops
.
get_component_trace
(
outputs
[
'traces'
]
=
dragnn_ops
.
get_component_trace
(
handle
,
component
=
self
.
spec
.
component
[
-
1
].
name
)
handle
,
component
=
self
.
spec
.
component
[
-
1
].
name
)
else
:
else
:
# The standard training scenario has gradients and updates counters.
# Standard training keeps track of the number of training steps.
handle
,
outputs
=
self
.
build_training
(
handle
,
compute_gradients
=
True
,
advance_counters
=
True
,
component_weights
=
target_config
.
component_weights
,
unroll_using_oracle
=
target_config
.
unroll_using_oracle
,
max_index
=
target_config
.
max_index
,
**
kwargs
)
# In addition, it keeps track of the number of training steps.
outputs
[
'target_step'
]
=
tf
.
get_variable
(
outputs
[
'target_step'
]
=
tf
.
get_variable
(
scope_id
+
'/TargetStep'
,
[],
scope_id
+
'/TargetStep'
,
[],
initializer
=
tf
.
zeros_initializer
(),
initializer
=
tf
.
zeros_initializer
(),
...
...
syntaxnet/dragnn/python/graph_builder_test.py
View file @
277f99c7
# Copyright 2017 Google Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for graph_builder."""
"""Tests for graph_builder."""
...
@@ -517,6 +532,23 @@ class GraphBuilderTest(test_util.TensorFlowTestCase):
...
@@ -517,6 +532,23 @@ class GraphBuilderTest(test_util.TensorFlowTestCase):
expected_num_actions
=
9
,
expected_num_actions
=
9
,
expected
=
_TAGGER_PARSER_EXPECTED_SENTENCES
)
expected
=
_TAGGER_PARSER_EXPECTED_SENTENCES
)
def
testStructuredTrainingNotImplementedDeath
(
self
):
spec
=
self
.
LoadSpec
(
'simple_parser_master_spec.textproto'
)
# Make the 'parser' component have a beam at training time.
self
.
assertEqual
(
'parser'
,
spec
.
component
[
0
].
name
)
spec
.
component
[
0
].
training_beam_size
=
8
# The training run should fail at runtime rather than build time.
with
self
.
assertRaisesRegexp
(
tf
.
errors
.
InvalidArgumentError
,
r
'\[Not implemented.\]'
):
self
.
RunFullTrainingAndInference
(
'simple-parser'
,
master_spec
=
spec
,
expected_num_actions
=
8
,
component_weights
=
[
1
],
expected
=
_LABELED_PARSER_EXPECTED_SENTENCES
)
def
testSimpleParser
(
self
):
def
testSimpleParser
(
self
):
self
.
RunFullTrainingAndInference
(
self
.
RunFullTrainingAndInference
(
'simple-parser'
,
'simple-parser'
,
...
...
syntaxnet/dragnn/python/network_units.py
View file @
277f99c7
# Copyright 2017 Google Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Basic network units used in assembling DRAGNN graphs."""
"""Basic network units used in assembling DRAGNN graphs."""
from
abc
import
ABCMeta
from
abc
import
ABCMeta
...
@@ -88,7 +103,7 @@ class NamedTensor(object):
...
@@ -88,7 +103,7 @@ class NamedTensor(object):
self
.
dim
=
dim
self
.
dim
=
dim
def
add_embeddings
(
channel_id
,
feature_spec
,
seed
):
def
add_embeddings
(
channel_id
,
feature_spec
,
seed
=
None
):
"""Adds a variable for the embedding of a given fixed feature.
"""Adds a variable for the embedding of a given fixed feature.
Supports pre-trained or randomly initialized embeddings In both cases, extra
Supports pre-trained or randomly initialized embeddings In both cases, extra
...
@@ -119,11 +134,14 @@ def add_embeddings(channel_id, feature_spec, seed):
...
@@ -119,11 +134,14 @@ def add_embeddings(channel_id, feature_spec, seed):
if
len
(
feature_spec
.
vocab
.
part
)
>
1
:
if
len
(
feature_spec
.
vocab
.
part
)
>
1
:
raise
RuntimeError
(
'vocab resource contains more than one part:
\n
%s'
,
raise
RuntimeError
(
'vocab resource contains more than one part:
\n
%s'
,
str
(
feature_spec
.
vocab
))
str
(
feature_spec
.
vocab
))
seed1
,
seed2
=
tf
.
get_seed
(
seed
)
embeddings
=
dragnn_ops
.
dragnn_embedding_initializer
(
embeddings
=
dragnn_ops
.
dragnn_embedding_initializer
(
embedding_input
=
feature_spec
.
pretrained_embedding_matrix
.
part
[
0
]
embedding_input
=
feature_spec
.
pretrained_embedding_matrix
.
part
[
0
]
.
file_pattern
,
.
file_pattern
,
vocab
=
feature_spec
.
vocab
.
part
[
0
].
file_pattern
,
vocab
=
feature_spec
.
vocab
.
part
[
0
].
file_pattern
,
scaling_coefficient
=
1.0
)
scaling_coefficient
=
1.0
,
seed
=
seed1
,
seed2
=
seed2
)
return
tf
.
get_variable
(
name
,
initializer
=
tf
.
reshape
(
embeddings
,
shape
))
return
tf
.
get_variable
(
name
,
initializer
=
tf
.
reshape
(
embeddings
,
shape
))
else
:
else
:
return
tf
.
get_variable
(
return
tf
.
get_variable
(
...
@@ -622,7 +640,6 @@ class NetworkUnitInterface(object):
...
@@ -622,7 +640,6 @@ class NetworkUnitInterface(object):
init_layers: optional initial layers.
init_layers: optional initial layers.
init_context_layers: optional initial context layers.
init_context_layers: optional initial context layers.
"""
"""
self
.
_seed
=
component
.
master
.
hyperparams
.
seed
self
.
_component
=
component
self
.
_component
=
component
self
.
_params
=
[]
self
.
_params
=
[]
self
.
_layers
=
init_layers
if
init_layers
else
[]
self
.
_layers
=
init_layers
if
init_layers
else
[]
...
@@ -640,7 +657,7 @@ class NetworkUnitInterface(object):
...
@@ -640,7 +657,7 @@ class NetworkUnitInterface(object):
check
.
Gt
(
spec
.
size
,
0
,
'Invalid fixed feature size'
)
check
.
Gt
(
spec
.
size
,
0
,
'Invalid fixed feature size'
)
if
spec
.
embedding_dim
>
0
:
if
spec
.
embedding_dim
>
0
:
fixed_dim
=
spec
.
embedding_dim
fixed_dim
=
spec
.
embedding_dim
self
.
_params
.
append
(
add_embeddings
(
channel_id
,
spec
,
self
.
_seed
))
self
.
_params
.
append
(
add_embeddings
(
channel_id
,
spec
))
else
:
else
:
fixed_dim
=
1
# assume feature ID extraction; only one ID per step
fixed_dim
=
1
# assume feature ID extraction; only one ID per step
self
.
_fixed_feature_dims
[
spec
.
name
]
=
spec
.
size
*
fixed_dim
self
.
_fixed_feature_dims
[
spec
.
name
]
=
spec
.
size
*
fixed_dim
...
@@ -663,7 +680,7 @@ class NetworkUnitInterface(object):
...
@@ -663,7 +680,7 @@ class NetworkUnitInterface(object):
linked_embeddings_name
(
channel_id
),
linked_embeddings_name
(
channel_id
),
[
source_array_dim
+
1
,
spec
.
embedding_dim
],
[
source_array_dim
+
1
,
spec
.
embedding_dim
],
initializer
=
tf
.
random_normal_initializer
(
initializer
=
tf
.
random_normal_initializer
(
stddev
=
1
/
spec
.
embedding_dim
**
.
5
,
seed
=
self
.
_seed
)))
stddev
=
1
/
spec
.
embedding_dim
**
.
5
)))
self
.
_linked_feature_dims
[
spec
.
name
]
=
spec
.
size
*
spec
.
embedding_dim
self
.
_linked_feature_dims
[
spec
.
name
]
=
spec
.
size
*
spec
.
embedding_dim
else
:
else
:
...
@@ -698,14 +715,12 @@ class NetworkUnitInterface(object):
...
@@ -698,14 +715,12 @@ class NetworkUnitInterface(object):
tf
.
get_variable
(
tf
.
get_variable
(
'attention_weights_pm_0'
,
'attention_weights_pm_0'
,
[
attention_hidden_layer_size
,
hidden_layer_size
],
[
attention_hidden_layer_size
,
hidden_layer_size
],
initializer
=
tf
.
random_normal_initializer
(
initializer
=
tf
.
random_normal_initializer
(
stddev
=
1e-4
)))
stddev
=
1e-4
,
seed
=
self
.
_seed
)))
self
.
_params
.
append
(
self
.
_params
.
append
(
tf
.
get_variable
(
tf
.
get_variable
(
'attention_weights_hm_0'
,
[
hidden_layer_size
,
hidden_layer_size
],
'attention_weights_hm_0'
,
[
hidden_layer_size
,
hidden_layer_size
],
initializer
=
tf
.
random_normal_initializer
(
initializer
=
tf
.
random_normal_initializer
(
stddev
=
1e-4
)))
stddev
=
1e-4
,
seed
=
self
.
_seed
)))
self
.
_params
.
append
(
self
.
_params
.
append
(
tf
.
get_variable
(
tf
.
get_variable
(
...
@@ -721,8 +736,7 @@ class NetworkUnitInterface(object):
...
@@ -721,8 +736,7 @@ class NetworkUnitInterface(object):
tf
.
get_variable
(
tf
.
get_variable
(
'attention_weights_pu'
,
'attention_weights_pu'
,
[
attention_hidden_layer_size
,
component
.
num_actions
],
[
attention_hidden_layer_size
,
component
.
num_actions
],
initializer
=
tf
.
random_normal_initializer
(
initializer
=
tf
.
random_normal_initializer
(
stddev
=
1e-4
)))
stddev
=
1e-4
,
seed
=
self
.
_seed
)))
@
abstractmethod
@
abstractmethod
def
create
(
self
,
def
create
(
self
,
...
@@ -961,8 +975,7 @@ class FeedForwardNetwork(NetworkUnitInterface):
...
@@ -961,8 +975,7 @@ class FeedForwardNetwork(NetworkUnitInterface):
for
index
,
hidden_layer_size
in
enumerate
(
self
.
_hidden_layer_sizes
):
for
index
,
hidden_layer_size
in
enumerate
(
self
.
_hidden_layer_sizes
):
weights
=
tf
.
get_variable
(
weights
=
tf
.
get_variable
(
'weights_%d'
%
index
,
[
last_layer_dim
,
hidden_layer_size
],
'weights_%d'
%
index
,
[
last_layer_dim
,
hidden_layer_size
],
initializer
=
tf
.
random_normal_initializer
(
stddev
=
1e-4
,
initializer
=
tf
.
random_normal_initializer
(
stddev
=
1e-4
))
seed
=
self
.
_seed
))
self
.
_params
.
append
(
weights
)
self
.
_params
.
append
(
weights
)
if
index
>
0
or
self
.
_layer_norm_hidden
is
None
:
if
index
>
0
or
self
.
_layer_norm_hidden
is
None
:
self
.
_params
.
append
(
self
.
_params
.
append
(
...
@@ -988,8 +1001,7 @@ class FeedForwardNetwork(NetworkUnitInterface):
...
@@ -988,8 +1001,7 @@ class FeedForwardNetwork(NetworkUnitInterface):
self
.
_params
.
append
(
self
.
_params
.
append
(
tf
.
get_variable
(
tf
.
get_variable
(
'weights_softmax'
,
[
last_layer_dim
,
component
.
num_actions
],
'weights_softmax'
,
[
last_layer_dim
,
component
.
num_actions
],
initializer
=
tf
.
random_normal_initializer
(
initializer
=
tf
.
random_normal_initializer
(
stddev
=
1e-4
)))
stddev
=
1e-4
,
seed
=
self
.
_seed
)))
self
.
_params
.
append
(
self
.
_params
.
append
(
tf
.
get_variable
(
tf
.
get_variable
(
'bias_softmax'
,
[
component
.
num_actions
],
'bias_softmax'
,
[
component
.
num_actions
],
...
@@ -1106,47 +1118,39 @@ class LSTMNetwork(NetworkUnitInterface):
...
@@ -1106,47 +1118,39 @@ class LSTMNetwork(NetworkUnitInterface):
# e.g. truncated_normal_initializer?
# e.g. truncated_normal_initializer?
self
.
_x2i
=
tf
.
get_variable
(
self
.
_x2i
=
tf
.
get_variable
(
'x2i'
,
[
layer_input_dim
,
self
.
_hidden_layer_sizes
],
'x2i'
,
[
layer_input_dim
,
self
.
_hidden_layer_sizes
],
initializer
=
tf
.
random_normal_initializer
(
stddev
=
1e-4
,
initializer
=
tf
.
random_normal_initializer
(
stddev
=
1e-4
))
seed
=
self
.
_seed
))
self
.
_h2i
=
tf
.
get_variable
(
self
.
_h2i
=
tf
.
get_variable
(
'h2i'
,
[
self
.
_hidden_layer_sizes
,
self
.
_hidden_layer_sizes
],
'h2i'
,
[
self
.
_hidden_layer_sizes
,
self
.
_hidden_layer_sizes
],
initializer
=
tf
.
random_normal_initializer
(
stddev
=
1e-4
,
initializer
=
tf
.
random_normal_initializer
(
stddev
=
1e-4
))
seed
=
self
.
_seed
))
self
.
_c2i
=
tf
.
get_variable
(
self
.
_c2i
=
tf
.
get_variable
(
'c2i'
,
[
self
.
_hidden_layer_sizes
,
self
.
_hidden_layer_sizes
],
'c2i'
,
[
self
.
_hidden_layer_sizes
,
self
.
_hidden_layer_sizes
],
initializer
=
tf
.
random_normal_initializer
(
stddev
=
1e-4
,
initializer
=
tf
.
random_normal_initializer
(
stddev
=
1e-4
))
seed
=
self
.
_seed
))
self
.
_bi
=
tf
.
get_variable
(
self
.
_bi
=
tf
.
get_variable
(
'bi'
,
[
self
.
_hidden_layer_sizes
],
'bi'
,
[
self
.
_hidden_layer_sizes
],
initializer
=
tf
.
random_normal_initializer
(
stddev
=
1e-4
,
seed
=
self
.
_seed
))
initializer
=
tf
.
random_normal_initializer
(
stddev
=
1e-4
))
self
.
_x2o
=
tf
.
get_variable
(
self
.
_x2o
=
tf
.
get_variable
(
'x2o'
,
[
layer_input_dim
,
self
.
_hidden_layer_sizes
],
'x2o'
,
[
layer_input_dim
,
self
.
_hidden_layer_sizes
],
initializer
=
tf
.
random_normal_initializer
(
stddev
=
1e-4
,
initializer
=
tf
.
random_normal_initializer
(
stddev
=
1e-4
))
seed
=
self
.
_seed
))
self
.
_h2o
=
tf
.
get_variable
(
self
.
_h2o
=
tf
.
get_variable
(
'h2o'
,
[
self
.
_hidden_layer_sizes
,
self
.
_hidden_layer_sizes
],
'h2o'
,
[
self
.
_hidden_layer_sizes
,
self
.
_hidden_layer_sizes
],
initializer
=
tf
.
random_normal_initializer
(
stddev
=
1e-4
,
initializer
=
tf
.
random_normal_initializer
(
stddev
=
1e-4
))
seed
=
self
.
_seed
))
self
.
_c2o
=
tf
.
get_variable
(
self
.
_c2o
=
tf
.
get_variable
(
'c2o'
,
[
self
.
_hidden_layer_sizes
,
self
.
_hidden_layer_sizes
],
'c2o'
,
[
self
.
_hidden_layer_sizes
,
self
.
_hidden_layer_sizes
],
initializer
=
tf
.
random_normal_initializer
(
stddev
=
1e-4
,
initializer
=
tf
.
random_normal_initializer
(
stddev
=
1e-4
))
seed
=
self
.
_seed
))
self
.
_bo
=
tf
.
get_variable
(
self
.
_bo
=
tf
.
get_variable
(
'bo'
,
[
self
.
_hidden_layer_sizes
],
'bo'
,
[
self
.
_hidden_layer_sizes
],
initializer
=
tf
.
random_normal_initializer
(
stddev
=
1e-4
,
seed
=
self
.
_seed
))
initializer
=
tf
.
random_normal_initializer
(
stddev
=
1e-4
))
self
.
_x2c
=
tf
.
get_variable
(
self
.
_x2c
=
tf
.
get_variable
(
'x2c'
,
[
layer_input_dim
,
self
.
_hidden_layer_sizes
],
'x2c'
,
[
layer_input_dim
,
self
.
_hidden_layer_sizes
],
initializer
=
tf
.
random_normal_initializer
(
stddev
=
1e-4
,
initializer
=
tf
.
random_normal_initializer
(
stddev
=
1e-4
))
seed
=
self
.
_seed
))
self
.
_h2c
=
tf
.
get_variable
(
self
.
_h2c
=
tf
.
get_variable
(
'h2c'
,
[
self
.
_hidden_layer_sizes
,
self
.
_hidden_layer_sizes
],
'h2c'
,
[
self
.
_hidden_layer_sizes
,
self
.
_hidden_layer_sizes
],
initializer
=
tf
.
random_normal_initializer
(
stddev
=
1e-4
,
initializer
=
tf
.
random_normal_initializer
(
stddev
=
1e-4
))
seed
=
self
.
_seed
))
self
.
_bc
=
tf
.
get_variable
(
self
.
_bc
=
tf
.
get_variable
(
'bc'
,
[
self
.
_hidden_layer_sizes
],
'bc'
,
[
self
.
_hidden_layer_sizes
],
initializer
=
tf
.
random_normal_initializer
(
stddev
=
1e-4
,
seed
=
self
.
_seed
))
initializer
=
tf
.
random_normal_initializer
(
stddev
=
1e-4
))
self
.
_params
.
extend
([
self
.
_params
.
extend
([
self
.
_x2i
,
self
.
_h2i
,
self
.
_c2i
,
self
.
_bi
,
self
.
_x2o
,
self
.
_h2o
,
self
.
_x2i
,
self
.
_h2i
,
self
.
_c2i
,
self
.
_bi
,
self
.
_x2o
,
self
.
_h2o
,
...
@@ -1166,8 +1170,7 @@ class LSTMNetwork(NetworkUnitInterface):
...
@@ -1166,8 +1170,7 @@ class LSTMNetwork(NetworkUnitInterface):
self
.
params
.
append
(
tf
.
get_variable
(
self
.
params
.
append
(
tf
.
get_variable
(
'weights_softmax'
,
[
self
.
_hidden_layer_sizes
,
component
.
num_actions
],
'weights_softmax'
,
[
self
.
_hidden_layer_sizes
,
component
.
num_actions
],
initializer
=
tf
.
random_normal_initializer
(
stddev
=
1e-4
,
initializer
=
tf
.
random_normal_initializer
(
stddev
=
1e-4
)))
seed
=
self
.
_seed
)))
self
.
params
.
append
(
self
.
params
.
append
(
tf
.
get_variable
(
tf
.
get_variable
(
'bias_softmax'
,
[
component
.
num_actions
],
'bias_softmax'
,
[
component
.
num_actions
],
...
@@ -1324,8 +1327,7 @@ class ConvNetwork(NetworkUnitInterface):
...
@@ -1324,8 +1327,7 @@ class ConvNetwork(NetworkUnitInterface):
tf
.
get_variable
(
tf
.
get_variable
(
'weights'
,
'weights'
,
self
.
kernel_shapes
[
i
],
self
.
kernel_shapes
[
i
],
initializer
=
tf
.
random_normal_initializer
(
initializer
=
tf
.
random_normal_initializer
(
stddev
=
1e-4
),
stddev
=
1e-4
,
seed
=
self
.
_seed
),
dtype
=
tf
.
float32
))
dtype
=
tf
.
float32
))
bias_init
=
0.0
if
(
i
==
len
(
self
.
_widths
)
-
1
)
else
0.2
bias_init
=
0.0
if
(
i
==
len
(
self
.
_widths
)
-
1
)
else
0.2
self
.
_biases
.
append
(
self
.
_biases
.
append
(
...
@@ -1473,8 +1475,7 @@ class PairwiseConvNetwork(NetworkUnitInterface):
...
@@ -1473,8 +1475,7 @@ class PairwiseConvNetwork(NetworkUnitInterface):
tf
.
get_variable
(
tf
.
get_variable
(
'weights'
,
'weights'
,
kernel_shape
,
kernel_shape
,
initializer
=
tf
.
random_normal_initializer
(
initializer
=
tf
.
random_normal_initializer
(
stddev
=
1e-4
),
stddev
=
1e-4
,
seed
=
self
.
_seed
),
dtype
=
tf
.
float32
))
dtype
=
tf
.
float32
))
bias_init
=
0.0
if
i
in
self
.
_relu_layers
else
0.2
bias_init
=
0.0
if
i
in
self
.
_relu_layers
else
0.2
self
.
_biases
.
append
(
self
.
_biases
.
append
(
...
...
syntaxnet/dragnn/python/network_units_test.py
View file @
277f99c7
# Copyright 2017 Google Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for network_units."""
"""Tests for network_units."""
...
...
syntaxnet/dragnn/python/render_parse_tree_graphviz.py
View file @
277f99c7
# -*- coding: utf-8 -*-
# -*- coding: utf-8 -*-
# Copyright 2017 Google Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Renders parse trees with Graphviz."""
"""Renders parse trees with Graphviz."""
from
__future__
import
absolute_import
from
__future__
import
absolute_import
from
__future__
import
division
from
__future__
import
division
...
...
syntaxnet/dragnn/python/render_parse_tree_graphviz_test.py
View file @
277f99c7
# Copyright 2017 Google Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for ....dragnn.python.render_parse_tree_graphviz."""
"""Tests for ....dragnn.python.render_parse_tree_graphviz."""
from
__future__
import
absolute_import
from
__future__
import
absolute_import
...
...
syntaxnet/dragnn/python/render_spec_with_graphviz.py
View file @
277f99c7
# -*- coding: utf-8 -*-
# -*- coding: utf-8 -*-
# Copyright 2017 Google Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Renders DRAGNN specs with Graphviz."""
"""Renders DRAGNN specs with Graphviz."""
from
__future__
import
absolute_import
from
__future__
import
absolute_import
from
__future__
import
division
from
__future__
import
division
...
...
syntaxnet/dragnn/python/render_spec_with_graphviz_test.py
View file @
277f99c7
# Copyright 2017 Google Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for render_spec_with_graphviz."""
"""Tests for render_spec_with_graphviz."""
from
__future__
import
absolute_import
from
__future__
import
absolute_import
...
...
syntaxnet/dragnn/python/sentence_io.py
View file @
277f99c7
# Copyright 2017 Google Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Utilities for reading and writing sentences in dragnn."""
"""Utilities for reading and writing sentences in dragnn."""
import
tensorflow
as
tf
import
tensorflow
as
tf
from
syntaxnet.ops
import
gen_parser_ops
from
syntaxnet.ops
import
gen_parser_ops
...
...
syntaxnet/dragnn/python/sentence_io_test.py
View file @
277f99c7
# Copyright 2017 Google Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
import
os
import
os
import
tensorflow
as
tf
import
tensorflow
as
tf
...
...
syntaxnet/dragnn/python/spec_builder.py
View file @
277f99c7
...
@@ -36,16 +36,20 @@ class ComponentSpecBuilder(object):
...
@@ -36,16 +36,20 @@ class ComponentSpecBuilder(object):
spec: The dragnn.ComponentSpec proto.
spec: The dragnn.ComponentSpec proto.
"""
"""
def
__init__
(
self
,
name
,
builder
=
'DynamicComponentBuilder'
):
def
__init__
(
self
,
name
,
builder
=
'DynamicComponentBuilder'
,
backend
=
'SyntaxNetComponent'
):
"""Initializes the ComponentSpec with some defaults for SyntaxNet.
"""Initializes the ComponentSpec with some defaults for SyntaxNet.
Args:
Args:
name: The name of this Component in the pipeline.
name: The name of this Component in the pipeline.
builder: The component builder type.
builder: The component builder type.
backend: The component backend type.
"""
"""
self
.
spec
=
spec_pb2
.
ComponentSpec
(
self
.
spec
=
spec_pb2
.
ComponentSpec
(
name
=
name
,
name
=
name
,
backend
=
self
.
make_module
(
'SyntaxNetComponent'
),
backend
=
self
.
make_module
(
backend
),
component_builder
=
self
.
make_module
(
builder
))
component_builder
=
self
.
make_module
(
builder
))
def
make_module
(
self
,
name
,
**
kwargs
):
def
make_module
(
self
,
name
,
**
kwargs
):
...
...
syntaxnet/dragnn/python/trainer_lib.py
View file @
277f99c7
# Copyright 2017 Google Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Utility functions to build DRAGNN MasterSpecs and schedule model training.
"""Utility functions to build DRAGNN MasterSpecs and schedule model training.
Provides functions to finish a MasterSpec, building required lexicons for it and
Provides functions to finish a MasterSpec, building required lexicons for it and
...
@@ -27,7 +42,7 @@ def calculate_component_accuracies(eval_res_values):
...
@@ -27,7 +42,7 @@ def calculate_component_accuracies(eval_res_values):
]
]
def
_
write_summary
(
summary_writer
,
label
,
value
,
step
):
def
write_summary
(
summary_writer
,
label
,
value
,
step
):
"""Write a summary for a certain evaluation."""
"""Write a summary for a certain evaluation."""
summary
=
Summary
(
value
=
[
Summary
.
Value
(
tag
=
label
,
simple_value
=
float
(
value
))])
summary
=
Summary
(
value
=
[
Summary
.
Value
(
tag
=
label
,
simple_value
=
float
(
value
))])
summary_writer
.
add_summary
(
summary
,
step
)
summary_writer
.
add_summary
(
summary
,
step
)
...
@@ -135,7 +150,7 @@ def run_training(sess, trainers, annotator, evaluator, pretrain_steps,
...
@@ -135,7 +150,7 @@ def run_training(sess, trainers, annotator, evaluator, pretrain_steps,
annotated
=
annotate_dataset
(
sess
,
annotator
,
eval_corpus
)
annotated
=
annotate_dataset
(
sess
,
annotator
,
eval_corpus
)
summaries
=
evaluator
(
eval_gold
,
annotated
)
summaries
=
evaluator
(
eval_gold
,
annotated
)
for
label
,
metric
in
summaries
.
iteritems
():
for
label
,
metric
in
summaries
.
iteritems
():
_
write_summary
(
summary_writer
,
label
,
metric
,
actual_step
+
step
)
write_summary
(
summary_writer
,
label
,
metric
,
actual_step
+
step
)
eval_metric
=
summaries
[
'eval_metric'
]
eval_metric
=
summaries
[
'eval_metric'
]
if
best_eval_metric
<
eval_metric
:
if
best_eval_metric
<
eval_metric
:
tf
.
logging
.
info
(
'Updating best eval to %.2f%%, saving checkpoint.'
,
tf
.
logging
.
info
(
'Updating best eval to %.2f%%, saving checkpoint.'
,
...
...
Prev
1
2
3
4
5
6
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment