Commit ea3fa4a3 authored by Ivan Bogatyy's avatar Ivan Bogatyy
Browse files

Update DRAGNN, fix some macOS issues

parent b7523ee5
// Copyright 2017 Google Inc. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
#include <math.h> #include <math.h>
#include <algorithm> #include <algorithm>
#include <utility> #include <utility>
......
// Copyright 2017 Google Inc. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
#include "dragnn/components/util/bulk_feature_extractor.h" #include "dragnn/components/util/bulk_feature_extractor.h"
#include "dragnn/core/compute_session_pool.h" #include "dragnn/core/compute_session_pool.h"
#include "dragnn/core/resource_container.h" #include "dragnn/core/resource_container.h"
......
// Copyright 2017 Google Inc. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
#include "tensorflow/core/framework/op.h" #include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/shape_inference.h" #include "tensorflow/core/framework/shape_inference.h"
...@@ -93,9 +108,13 @@ REGISTER_OP("BulkAdvanceFromPrediction") ...@@ -93,9 +108,13 @@ REGISTER_OP("BulkAdvanceFromPrediction")
.Output("output_handle: string") .Output("output_handle: string")
.Attr("component: string") .Attr("component: string")
.Attr("T: type") .Attr("T: type")
.SetShapeFn([](tensorflow::shape_inference::InferenceContext *context) { .SetShapeFn([](tensorflow::shape_inference::InferenceContext *c) {
auto scores = context->input(1); tensorflow::shape_inference::ShapeHandle handle;
TF_RETURN_IF_ERROR(context->WithRank(scores, 2, &scores)); TF_RETURN_IF_ERROR(c->Merge(c->input(0), c->Vector(2), &handle));
c->set_output(0, handle);
auto scores = c->input(1);
TF_RETURN_IF_ERROR(c->WithRank(scores, 2, &scores));
return tensorflow::Status::OK(); return tensorflow::Status::OK();
}) })
.Doc(R"doc( .Doc(R"doc(
......
// Copyright 2017 Google Inc. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
#include <memory> #include <memory>
#include <string> #include <string>
#include <vector> #include <vector>
......
// Copyright 2017 Google Inc. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
#include <functional> #include <functional>
#include <memory> #include <memory>
#include <vector> #include <vector>
......
// Copyright 2017 Google Inc. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
#include "tensorflow/core/framework/op.h" #include "tensorflow/core/framework/op.h"
namespace syntaxnet { namespace syntaxnet {
...@@ -113,6 +128,8 @@ REGISTER_OP("DragnnEmbeddingInitializer") ...@@ -113,6 +128,8 @@ REGISTER_OP("DragnnEmbeddingInitializer")
.Attr("embedding_input: string") .Attr("embedding_input: string")
.Attr("vocab: string") .Attr("vocab: string")
.Attr("scaling_coefficient: float = 1.0") .Attr("scaling_coefficient: float = 1.0")
.Attr("seed: int = 0")
.Attr("seed2: int = 0")
.Doc(R"doc( .Doc(R"doc(
*** PLACEHOLDER OP - FUNCTIONALITY NOT YET IMPLEMENTED *** *** PLACEHOLDER OP - FUNCTIONALITY NOT YET IMPLEMENTED ***
...@@ -122,6 +139,10 @@ embeddings: A tensor containing embeddings from the specified sstable. ...@@ -122,6 +139,10 @@ embeddings: A tensor containing embeddings from the specified sstable.
embedding_input: Path to location with embedding vectors. embedding_input: Path to location with embedding vectors.
vocab: Path to list of keys corresponding to the input. vocab: Path to list of keys corresponding to the input.
scaling_coefficient: A scaling coefficient for the embedding matrix. scaling_coefficient: A scaling coefficient for the embedding matrix.
seed: If either `seed` or `seed2` are set to be non-zero, the random number
generator is seeded by the given seed. Otherwise, it is seeded by a
random seed.
seed2: A second seed to avoid seed collision.
)doc"); )doc");
REGISTER_OP("ExtractFixedFeatures") REGISTER_OP("ExtractFixedFeatures")
......
// Copyright 2017 Google Inc. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
#ifndef NLP_SAFT_OPENSOURCE_DRAGNN_CORE_RESOURCE_CONTAINER_H_ #ifndef NLP_SAFT_OPENSOURCE_DRAGNN_CORE_RESOURCE_CONTAINER_H_
#define NLP_SAFT_OPENSOURCE_DRAGNN_CORE_RESOURCE_CONTAINER_H_ #define NLP_SAFT_OPENSOURCE_DRAGNN_CORE_RESOURCE_CONTAINER_H_
......
// Copyright 2017 Google Inc. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
// Tests the methods of ResourceContainer. // Tests the methods of ResourceContainer.
// //
// NOTE(danielandor): For all tests: ResourceContainer is derived from // NOTE(danielandor): For all tests: ResourceContainer is derived from
......
package(default_visibility = ["//visibility:public"]) package(
default_visibility = ["//visibility:public"],
features = ["-layering_check"],
)
cc_library( cc_library(
name = "mock_component", name = "mock_component",
...@@ -13,7 +16,6 @@ cc_library( ...@@ -13,7 +16,6 @@ cc_library(
"//dragnn/protos:spec_proto", "//dragnn/protos:spec_proto",
"//syntaxnet:base", "//syntaxnet:base",
"//syntaxnet:test_main", "//syntaxnet:test_main",
"@org_tensorflow//tensorflow/core:test",
], ],
) )
...@@ -27,7 +29,7 @@ cc_library( ...@@ -27,7 +29,7 @@ cc_library(
"//dragnn/protos:data_proto", "//dragnn/protos:data_proto",
"//dragnn/protos:spec_proto", "//dragnn/protos:spec_proto",
"//syntaxnet:base", "//syntaxnet:base",
"@org_tensorflow//tensorflow/core:test", "//syntaxnet:test_main",
], ],
) )
...@@ -39,7 +41,6 @@ cc_library( ...@@ -39,7 +41,6 @@ cc_library(
"//dragnn/core/interfaces:transition_state", "//dragnn/core/interfaces:transition_state",
"//syntaxnet:base", "//syntaxnet:base",
"//syntaxnet:test_main", "//syntaxnet:test_main",
"@org_tensorflow//tensorflow/core:test",
], ],
) )
...@@ -50,8 +51,6 @@ cc_library( ...@@ -50,8 +51,6 @@ cc_library(
hdrs = ["generic.h"], hdrs = ["generic.h"],
deps = [ deps = [
"//syntaxnet:base", "//syntaxnet:base",
"@org_tensorflow//tensorflow/core:lib", "//syntaxnet:test_main",
"@org_tensorflow//tensorflow/core:test",
"@org_tensorflow//tensorflow/core:testlib",
], ],
) )
// Copyright 2017 Google Inc. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
#include "dragnn/core/test/generic.h" #include "dragnn/core/test/generic.h"
#include "tensorflow/core/lib/io/path.h" #include "tensorflow/core/lib/io/path.h"
......
// Copyright 2017 Google Inc. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
#ifndef NLP_SAFT_OPENSOURCE_DRAGNN_CORE_TEST_GENERIC_H_ #ifndef NLP_SAFT_OPENSOURCE_DRAGNN_CORE_TEST_GENERIC_H_
#define NLP_SAFT_OPENSOURCE_DRAGNN_CORE_TEST_GENERIC_H_ #define NLP_SAFT_OPENSOURCE_DRAGNN_CORE_TEST_GENERIC_H_
......
// Copyright 2017 Google Inc. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
#ifndef NLP_SAFT_OPENSOURCE_DRAGNN_CORE_TEST_MOCK_COMPONENT_H_ #ifndef NLP_SAFT_OPENSOURCE_DRAGNN_CORE_TEST_MOCK_COMPONENT_H_
#define NLP_SAFT_OPENSOURCE_DRAGNN_CORE_TEST_MOCK_COMPONENT_H_ #define NLP_SAFT_OPENSOURCE_DRAGNN_CORE_TEST_MOCK_COMPONENT_H_
......
// Copyright 2017 Google Inc. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
#ifndef NLP_SAFT_OPENSOURCE_DRAGNN_CORE_TEST_MOCK_COMPUTE_SESSION_H_ #ifndef NLP_SAFT_OPENSOURCE_DRAGNN_CORE_TEST_MOCK_COMPUTE_SESSION_H_
#define NLP_SAFT_OPENSOURCE_DRAGNN_CORE_TEST_MOCK_COMPUTE_SESSION_H_ #define NLP_SAFT_OPENSOURCE_DRAGNN_CORE_TEST_MOCK_COMPUTE_SESSION_H_
......
// Copyright 2017 Google Inc. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
#ifndef NLP_SAFT_OPENSOURCE_DRAGNN_CORE_TEST_MOCK_TRANSITION_STATE_H_ #ifndef NLP_SAFT_OPENSOURCE_DRAGNN_CORE_TEST_MOCK_TRANSITION_STATE_H_
#define NLP_SAFT_OPENSOURCE_DRAGNN_CORE_TEST_MOCK_TRANSITION_STATE_H_ #define NLP_SAFT_OPENSOURCE_DRAGNN_CORE_TEST_MOCK_TRANSITION_STATE_H_
......
// Copyright 2017 Google Inc. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
#include "dragnn/io/sentence_input_batch.h" #include "dragnn/io/sentence_input_batch.h"
#include "syntaxnet/sentence.pb.h" #include "syntaxnet/sentence.pb.h"
......
// Copyright 2017 Google Inc. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
#ifndef NLP_SAFT_OPENSOURCE_DRAGNN_IO_SENTENCE_INPUT_BATCH_H_ #ifndef NLP_SAFT_OPENSOURCE_DRAGNN_IO_SENTENCE_INPUT_BATCH_H_
#define NLP_SAFT_OPENSOURCE_DRAGNN_IO_SENTENCE_INPUT_BATCH_H_ #define NLP_SAFT_OPENSOURCE_DRAGNN_IO_SENTENCE_INPUT_BATCH_H_
......
// Copyright 2017 Google Inc. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
#include "dragnn/io/sentence_input_batch.h" #include "dragnn/io/sentence_input_batch.h"
#include "dragnn/core/test/generic.h" #include "dragnn/core/test/generic.h"
......
// Copyright 2017 Google Inc. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
#ifndef NLP_SAFT_OPENSOURCE_DRAGNN_IO_SYNTAXNET_SENTENCE_H_ #ifndef NLP_SAFT_OPENSOURCE_DRAGNN_IO_SYNTAXNET_SENTENCE_H_
#define NLP_SAFT_OPENSOURCE_DRAGNN_IO_SYNTAXNET_SENTENCE_H_ #define NLP_SAFT_OPENSOURCE_DRAGNN_IO_SYNTAXNET_SENTENCE_H_
......
...@@ -9,6 +9,7 @@ cc_binary( ...@@ -9,6 +9,7 @@ cc_binary(
linkshared = 1, linkshared = 1,
linkstatic = 1, linkstatic = 1,
deps = [ deps = [
"//dragnn/components/stateless:stateless_component",
"//dragnn/components/syntaxnet:syntaxnet_component", "//dragnn/components/syntaxnet:syntaxnet_component",
"//dragnn/core:dragnn_bulk_ops_cc", "//dragnn/core:dragnn_bulk_ops_cc",
"//dragnn/core:dragnn_ops_cc", "//dragnn/core:dragnn_ops_cc",
......
# Copyright 2017 Google Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Network units used in the Dozat and Manning (2017) biaffine parser.""" """Network units used in the Dozat and Manning (2017) biaffine parser."""
from __future__ import absolute_import from __future__ import absolute_import
...@@ -68,13 +83,13 @@ class BiaffineDigraphNetwork(network_units.NetworkUnitInterface): ...@@ -68,13 +83,13 @@ class BiaffineDigraphNetwork(network_units.NetworkUnitInterface):
self._weights = [] self._weights = []
self._weights.append(tf.get_variable( self._weights.append(tf.get_variable(
'weights_arc', [self._source_dim, self._target_dim], tf.float32, 'weights_arc', [self._source_dim, self._target_dim], tf.float32,
tf.random_normal_initializer(stddev=1e-4, seed=self._seed))) tf.random_normal_initializer(stddev=1e-4)))
self._weights.append(tf.get_variable( self._weights.append(tf.get_variable(
'weights_source', [self._source_dim], tf.float32, 'weights_source', [self._source_dim], tf.float32,
tf.random_normal_initializer(stddev=1e-4, seed=self._seed))) tf.random_normal_initializer(stddev=1e-4)))
self._weights.append(tf.get_variable( self._weights.append(tf.get_variable(
'root', [self._source_dim], tf.float32, 'root', [self._source_dim], tf.float32,
tf.random_normal_initializer(stddev=1e-4, seed=self._seed))) tf.random_normal_initializer(stddev=1e-4)))
self._params.extend(self._weights) self._params.extend(self._weights)
self._regularized_weights.extend(self._weights) self._regularized_weights.extend(self._weights)
...@@ -178,18 +193,18 @@ class BiaffineLabelNetwork(network_units.NetworkUnitInterface): ...@@ -178,18 +193,18 @@ class BiaffineLabelNetwork(network_units.NetworkUnitInterface):
self._weights = [] self._weights = []
self._weights.append(tf.get_variable( self._weights.append(tf.get_variable(
'weights_pair', [self._num_labels, self._source_dim, self._target_dim], 'weights_pair', [self._num_labels, self._source_dim, self._target_dim],
tf.float32, tf.random_normal_initializer(stddev=1e-4, seed=self._seed))) tf.float32, tf.random_normal_initializer(stddev=1e-4)))
self._weights.append(tf.get_variable( self._weights.append(tf.get_variable(
'weights_source', [self._num_labels, self._source_dim], tf.float32, 'weights_source', [self._num_labels, self._source_dim], tf.float32,
tf.random_normal_initializer(stddev=1e-4, seed=self._seed))) tf.random_normal_initializer(stddev=1e-4)))
self._weights.append(tf.get_variable( self._weights.append(tf.get_variable(
'weights_target', [self._num_labels, self._target_dim], tf.float32, 'weights_target', [self._num_labels, self._target_dim], tf.float32,
tf.random_normal_initializer(stddev=1e-4, seed=self._seed))) tf.random_normal_initializer(stddev=1e-4)))
self._biases = [] self._biases = []
self._biases.append(tf.get_variable( self._biases.append(tf.get_variable(
'biases', [self._num_labels], tf.float32, 'biases', [self._num_labels], tf.float32,
tf.random_normal_initializer(stddev=1e-4, seed=self._seed))) tf.random_normal_initializer(stddev=1e-4)))
self._params.extend(self._weights + self._biases) self._params.extend(self._weights + self._biases)
self._regularized_weights.extend(self._weights) self._regularized_weights.extend(self._weights)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment