Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
ModelZoo
ResNet50_tensorflow
Commits
a4bb31d0
Commit
a4bb31d0
authored
May 02, 2018
by
Terry Koo
Browse files
Export @195097388.
parent
dea7ecf6
Changes
440
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
207 additions
and
86 deletions
+207
-86
research/syntaxnet/Dockerfile
research/syntaxnet/Dockerfile
+5
-5
research/syntaxnet/README.md
research/syntaxnet/README.md
+6
-6
research/syntaxnet/WORKSPACE
research/syntaxnet/WORKSPACE
+23
-12
research/syntaxnet/docker-devel/Dockerfile-test
research/syntaxnet/docker-devel/Dockerfile-test
+1
-0
research/syntaxnet/docker-devel/Dockerfile-test-base
research/syntaxnet/docker-devel/Dockerfile-test-base
+7
-10
research/syntaxnet/docker-devel/Dockerfile.min
research/syntaxnet/docker-devel/Dockerfile.min
+1
-1
research/syntaxnet/dragnn/components/stateless/BUILD
research/syntaxnet/dragnn/components/stateless/BUILD
+3
-2
research/syntaxnet/dragnn/components/stateless/stateless_component.cc
...taxnet/dragnn/components/stateless/stateless_component.cc
+14
-4
research/syntaxnet/dragnn/components/syntaxnet/BUILD
research/syntaxnet/dragnn/components/syntaxnet/BUILD
+13
-11
research/syntaxnet/dragnn/components/syntaxnet/syntaxnet_component.cc
...taxnet/dragnn/components/syntaxnet/syntaxnet_component.cc
+34
-15
research/syntaxnet/dragnn/components/syntaxnet/syntaxnet_component.h
...ntaxnet/dragnn/components/syntaxnet/syntaxnet_component.h
+13
-1
research/syntaxnet/dragnn/components/syntaxnet/syntaxnet_component_test.cc
...t/dragnn/components/syntaxnet/syntaxnet_component_test.cc
+28
-0
research/syntaxnet/dragnn/conll2017/BUILD
research/syntaxnet/dragnn/conll2017/BUILD
+2
-1
research/syntaxnet/dragnn/conll2017/make_parser_spec.py
research/syntaxnet/dragnn/conll2017/make_parser_spec.py
+1
-1
research/syntaxnet/dragnn/conll2017/sample/zh-segmenter.checkpoint.data-00000-of-00001
...ll2017/sample/zh-segmenter.checkpoint.data-00000-of-00001
+0
-0
research/syntaxnet/dragnn/conll2017/sample/zh-segmenter.checkpoint.index
...net/dragnn/conll2017/sample/zh-segmenter.checkpoint.index
+0
-0
research/syntaxnet/dragnn/conll2017/sample/zh-segmenter.checkpoint.meta
...xnet/dragnn/conll2017/sample/zh-segmenter.checkpoint.meta
+0
-0
research/syntaxnet/dragnn/core/BUILD
research/syntaxnet/dragnn/core/BUILD
+38
-12
research/syntaxnet/dragnn/core/compute_session.h
research/syntaxnet/dragnn/core/compute_session.h
+5
-1
research/syntaxnet/dragnn/core/compute_session_impl.cc
research/syntaxnet/dragnn/core/compute_session_impl.cc
+13
-4
No files found.
research/syntaxnet/Dockerfile
View file @
a4bb31d0
FROM
ubuntu:16.
1
0
FROM
ubuntu:16.0
4
ENV
SYNTAXNETDIR=/opt/tensorflow PATH=$PATH:/root/bin
ENV
SYNTAXNETDIR=/opt/tensorflow PATH=$PATH:/root/bin
...
@@ -57,10 +57,10 @@ RUN python -m pip install \
...
@@ -57,10 +57,10 @@ RUN python -m pip install \
&&
rm
-rf
/root/.cache/pip /tmp/pip
*
&&
rm
-rf
/root/.cache/pip /tmp/pip
*
# Installs Bazel.
# Installs Bazel.
RUN
wget
--quiet
https://github.com/bazelbuild/bazel/releases/download/0.
8
.1/bazel-0.
8
.1-installer-linux-x86_64.sh
\
RUN
wget
--quiet
https://github.com/bazelbuild/bazel/releases/download/0.
11
.1/bazel-0.
11
.1-installer-linux-x86_64.sh
\
&&
chmod
+x bazel-0.
8
.1-installer-linux-x86_64.sh
\
&&
chmod
+x bazel-0.
11
.1-installer-linux-x86_64.sh
\
&&
./bazel-0.
8
.1-installer-linux-x86_64.sh
\
&&
./bazel-0.
11
.1-installer-linux-x86_64.sh
\
&&
rm
./bazel-0.
8
.1-installer-linux-x86_64.sh
&&
rm
./bazel-0.
11
.1-installer-linux-x86_64.sh
COPY
WORKSPACE $SYNTAXNETDIR/syntaxnet/WORKSPACE
COPY
WORKSPACE $SYNTAXNETDIR/syntaxnet/WORKSPACE
COPY
tools/bazel.rc $SYNTAXNETDIR/syntaxnet/tools/bazel.rc
COPY
tools/bazel.rc $SYNTAXNETDIR/syntaxnet/tools/bazel.rc
...
...
research/syntaxnet/README.md
View file @
a4bb31d0
...
@@ -60,10 +60,10 @@ The simplest way to get started with DRAGNN is by loading our Docker container.
...
@@ -60,10 +60,10 @@ The simplest way to get started with DRAGNN is by loading our Docker container.
[
Here
](
g3doc/CLOUD.md
)
is a tutorial for running the DRAGNN container on
[
Here
](
g3doc/CLOUD.md
)
is a tutorial for running the DRAGNN container on
[
GCP
](
https://cloud.google.com
)
(
just
as applicable to your own computer).
[
GCP
](
https://cloud.google.com
)
(
just
as applicable to your own computer).
### Ubuntu 16.
1
0+ binary installation
### Ubuntu 16.0
4
+ binary installation
_This process takes ~5 minutes, but is only compatible with Linux using GNU libc
_This process takes ~5 minutes, but is only compatible with Linux using GNU libc
3.
4.22 and above (e.g. Ubuntu 16.
1
0)._
3.
4.22 and above (e.g. Ubuntu 16.0
4
)._
Binary wheel packages are provided for TensorFlow and SyntaxNet. If you do not
Binary wheel packages are provided for TensorFlow and SyntaxNet. If you do not
need to write new binary TensorFlow ops, these should suffice.
need to write new binary TensorFlow ops, these should suffice.
...
@@ -92,9 +92,9 @@ source. You'll need to install:
...
@@ -92,9 +92,9 @@ source. You'll need to install:
*
python 2.7:
*
python 2.7:
*
Python 3 support is not available yet
*
Python 3 support is not available yet
*
bazel 0.
5.4
:
*
bazel 0.
11.1
:
*
Follow the instructions
[
here
](
http://bazel.build/docs/install.html
)
*
Follow the instructions
[
here
](
http://bazel.build/docs/install.html
)
*
Alternately, Download bazel 0.
5.4
<
.
deb
>
from
*
Alternately, Download bazel 0.
11.1
<
.
deb
>
from
[
https://github.com/bazelbuild/bazel/releases
](
https://github.com/bazelbuild/bazel/releases
)
[
https://github.com/bazelbuild/bazel/releases
](
https://github.com/bazelbuild/bazel/releases
)
for your system configuration.
for your system configuration.
*
Install it using the command: sudo dpkg -i
<
.
deb
file
>
*
Install it using the command: sudo dpkg -i
<
.
deb
file
>
...
@@ -105,14 +105,14 @@ source. You'll need to install:
...
@@ -105,14 +105,14 @@ source. You'll need to install:
*
protocol buffers, with a version supported by TensorFlow:
*
protocol buffers, with a version supported by TensorFlow:
*
check your protobuf version with
`pip freeze | grep protobuf`
*
check your protobuf version with
`pip freeze | grep protobuf`
*
upgrade to a supported version with
`pip install -U protobuf==3.3.0`
*
upgrade to a supported version with
`pip install -U protobuf==3.3.0`
*
autograd, with a version supported by TensorFlow:
*
`pip install -U autograd==1.1.13`
*
mock, the testing package:
*
mock, the testing package:
*
`pip install mock`
*
`pip install mock`
*
asciitree, to draw parse trees on the console for the demo:
*
asciitree, to draw parse trees on the console for the demo:
*
`pip install asciitree`
*
`pip install asciitree`
*
numpy, package for scientific computing:
*
numpy, package for scientific computing:
*
`pip install numpy`
*
`pip install numpy`
*
autograd 1.1.13, for automatic differentiation (not yet compatible with autograd v1.2 rewrite):
*
`pip install autograd==1.1.13`
*
pygraphviz to visualize traces and parse trees:
*
pygraphviz to visualize traces and parse trees:
*
`apt-get install -y graphviz libgraphviz-dev`
*
`apt-get install -y graphviz libgraphviz-dev`
*
`pip install pygraphviz
*
`pip install pygraphviz
...
...
research/syntaxnet/WORKSPACE
View file @
a4bb31d0
local_repository
(
local_repository
(
name
=
"org_tensorflow"
,
name
=
"org_tensorflow"
,
path
=
"tensorflow"
,
path
=
"tensorflow"
,
)
)
# We need to pull in @io_bazel_rules_closure for TensorFlow. Bazel design
# We need to pull in @io_bazel_rules_closure for TensorFlow. Bazel design
...
@@ -9,22 +9,33 @@ local_repository(
...
@@ -9,22 +9,33 @@ local_repository(
# @io_bazel_rules_closure.
# @io_bazel_rules_closure.
http_archive
(
http_archive
(
name
=
"io_bazel_rules_closure"
,
name
=
"io_bazel_rules_closure"
,
sha256
=
"
25f5399f18d8bf9ce435f85c6bbf671ec4820bc4396b3022cc5dc4bc66303609
"
,
sha256
=
"
6691c58a2cd30a86776dd9bb34898b041e37136f2dc7e24cadaeaf599c95c657
"
,
strip_prefix
=
"rules_closure-0
.4.2
"
,
strip_prefix
=
"rules_closure-0
8039ba8ca59f64248bb3b6ae016460fe9c9914f
"
,
urls
=
[
urls
=
[
"http://bazel-mirror.storage.googleapis.com/github.com/bazelbuild/rules_closure/archive/0
.4.2.tar.gz"
,
# 2017-08-30
"http://bazel-mirror.storage.googleapis.com/github.com/bazelbuild/rules_closure/archive/0
8039ba8ca59f64248bb3b6ae016460fe9c9914f.tar.gz"
,
"https://github.com/bazelbuild/rules_closure/archive/0
.4.2
.tar.gz"
,
"https://github.com/bazelbuild/rules_closure/archive/0
8039ba8ca59f64248bb3b6ae016460fe9c9914f
.tar.gz"
,
],
],
)
)
load
(
"@org_tensorflow//tensorflow:workspace.bzl"
,
"tf_workspace"
)
load
(
"@org_tensorflow//tensorflow:workspace.bzl"
,
"tf_workspace"
)
tf_workspace
(
path_prefix
=
""
,
tf_repo_name
=
"org_tensorflow"
)
# Test that Bazel is up-to-date.
tf_workspace
(
load
(
"@org_tensorflow//tensorflow:workspace.bzl"
,
"check_version"
)
path_prefix
=
""
,
check_version
(
"0.4.2"
)
tf_repo_name
=
"org_tensorflow"
,
)
http_archive
(
name
=
"sling"
,
sha256
=
"f1ce597476cb024808ca0a371a01db9dda4e0c58fb34a4f9c4ea91796f437b10"
,
strip_prefix
=
"sling-e3ae9d94eb1d9ee037a851070d54ed2eefaa928a"
,
urls
=
[
"http://bazel-mirror.storage.googleapis.com/github.com/google/sling/archive/e3ae9d94eb1d9ee037a851070d54ed2eefaa928a.tar.gz"
,
"https://github.com/google/sling/archive/e3ae9d94eb1d9ee037a851070d54ed2eefaa928a.tar.gz"
,
],
)
# Used by SLING.
bind
(
bind
(
name
=
"
protobuf
"
,
name
=
"
zlib
"
,
actual
=
"@
protobuf
_archive//:
protobuf
"
,
actual
=
"@
zlib
_archive//:
zlib
"
,
)
)
research/syntaxnet/docker-devel/Dockerfile-test
View file @
a4bb31d0
...
@@ -9,3 +9,4 @@ COPY dragnn $SYNTAXNETDIR/syntaxnet/dragnn
...
@@ -9,3 +9,4 @@ COPY dragnn $SYNTAXNETDIR/syntaxnet/dragnn
COPY syntaxnet $SYNTAXNETDIR/syntaxnet/syntaxnet
COPY syntaxnet $SYNTAXNETDIR/syntaxnet/syntaxnet
COPY third_party $SYNTAXNETDIR/syntaxnet/third_party
COPY third_party $SYNTAXNETDIR/syntaxnet/third_party
COPY util/utf8 $SYNTAXNETDIR/syntaxnet/util/utf8
COPY util/utf8 $SYNTAXNETDIR/syntaxnet/util/utf8
COPY WORKSPACE $SYNTAXNETDIR/syntaxnet/WORKSPACE
research/syntaxnet/docker-devel/Dockerfile-test-base
View file @
a4bb31d0
FROM ubuntu:16.
1
0
FROM ubuntu:16.0
4
ENV SYNTAXNETDIR=/opt/tensorflow PATH=$PATH:/root/bin
ENV SYNTAXNETDIR=/opt/tensorflow PATH=$PATH:/root/bin
...
@@ -57,10 +57,10 @@ RUN python -m pip install \
...
@@ -57,10 +57,10 @@ RUN python -m pip install \
&& rm -rf /root/.cache/pip /tmp/pip*
&& rm -rf /root/.cache/pip /tmp/pip*
# Installs Bazel.
# Installs Bazel.
RUN wget --quiet https://github.com/bazelbuild/bazel/releases/download/0.
5.3
/bazel-0.
5.3
-installer-linux-x86_64.sh \
RUN wget --quiet https://github.com/bazelbuild/bazel/releases/download/0.
11.1
/bazel-0.
11.1
-installer-linux-x86_64.sh \
&& chmod +x bazel-0.
5.3
-installer-linux-x86_64.sh \
&& chmod +x bazel-0.
11.1
-installer-linux-x86_64.sh \
&& JAVA_HOME=/usr/lib/jvm/java-8-openjdk-amd64/ ./bazel-0.
5.3
-installer-linux-x86_64.sh \
&& JAVA_HOME=/usr/lib/jvm/java-8-openjdk-amd64/ ./bazel-0.
11.1
-installer-linux-x86_64.sh \
&& rm ./bazel-0.
5.3
-installer-linux-x86_64.sh
&& rm ./bazel-0.
11.1
-installer-linux-x86_64.sh
COPY WORKSPACE $SYNTAXNETDIR/syntaxnet/WORKSPACE
COPY WORKSPACE $SYNTAXNETDIR/syntaxnet/WORKSPACE
COPY tools/bazel.rc $SYNTAXNETDIR/syntaxnet/tools/bazel.rc
COPY tools/bazel.rc $SYNTAXNETDIR/syntaxnet/tools/bazel.rc
...
@@ -69,12 +69,9 @@ COPY tools/bazel.rc $SYNTAXNETDIR/syntaxnet/tools/bazel.rc
...
@@ -69,12 +69,9 @@ COPY tools/bazel.rc $SYNTAXNETDIR/syntaxnet/tools/bazel.rc
# source. This makes it more convenient to re-compile DRAGNN / SyntaxNet for
# source. This makes it more convenient to re-compile DRAGNN / SyntaxNet for
# development (though not as convenient as the docker-devel scripts).
# development (though not as convenient as the docker-devel scripts).
RUN cd $SYNTAXNETDIR/syntaxnet \
RUN cd $SYNTAXNETDIR/syntaxnet \
&& git clone --branch r1.
3
--recurse-submodules https://github.com/tensorflow/tensorflow \
&& git clone --branch r1.
8
--recurse-submodules https://github.com/tensorflow/tensorflow \
&& cd tensorflow \
&& cd tensorflow \
# This line removes a bad archive target which causes Tensorflow install
&& tensorflow/tools/ci_build/builds/configured CPU \
# to fail.
&& sed -i '\@https://github.com/google/protobuf/archive/0b059a3d8a8f8aa40dde7bea55edca4ec5dfea66.tar.gz@d' tensorflow/workspace.bzl \
&& tensorflow/tools/ci_build/builds/configured CPU \\
&& cd $SYNTAXNETDIR/syntaxnet \
&& cd $SYNTAXNETDIR/syntaxnet \
&& bazel build -c opt @org_tensorflow//tensorflow:tensorflow_py
&& bazel build -c opt @org_tensorflow//tensorflow:tensorflow_py
...
...
research/syntaxnet/docker-devel/Dockerfile.min
View file @
a4bb31d0
...
@@ -3,7 +3,7 @@
...
@@ -3,7 +3,7 @@
#
#
# It might be more efficient to use a minimal distribution, like Alpine. But
# It might be more efficient to use a minimal distribution, like Alpine. But
# the upside of this being popular is that people might already have it.
# the upside of this being popular is that people might already have it.
FROM ubuntu:16.
1
0
FROM ubuntu:16.0
4
ENV SYNTAXNETDIR=/opt/tensorflow PATH=$PATH:/root/bin
ENV SYNTAXNETDIR=/opt/tensorflow PATH=$PATH:/root/bin
...
...
research/syntaxnet/dragnn/components/stateless/BUILD
View file @
a4bb31d0
...
@@ -10,7 +10,8 @@ cc_library(
...
@@ -10,7 +10,8 @@ cc_library(
"//dragnn/core:component_registry"
,
"//dragnn/core:component_registry"
,
"//dragnn/core/interfaces:component"
,
"//dragnn/core/interfaces:component"
,
"//dragnn/core/interfaces:transition_state"
,
"//dragnn/core/interfaces:transition_state"
,
"//dragnn/protos:data_proto"
,
"//dragnn/core/util:label"
,
"//dragnn/protos:data_proto_cc"
,
"//syntaxnet:base"
,
"//syntaxnet:base"
,
],
],
alwayslink
=
1
,
alwayslink
=
1
,
...
@@ -27,7 +28,7 @@ cc_test(
...
@@ -27,7 +28,7 @@ cc_test(
"//dragnn/core/test:mock_transition_state"
,
"//dragnn/core/test:mock_transition_state"
,
"//dragnn/io:sentence_input_batch"
,
"//dragnn/io:sentence_input_batch"
,
"//syntaxnet:base"
,
"//syntaxnet:base"
,
"//syntaxnet:sentence_proto"
,
"//syntaxnet:sentence_proto
_cc
"
,
"//syntaxnet:test_main"
,
"//syntaxnet:test_main"
,
],
],
)
)
research/syntaxnet/dragnn/components/stateless/stateless_component.cc
View file @
a4bb31d0
...
@@ -16,6 +16,7 @@
...
@@ -16,6 +16,7 @@
#include "dragnn/core/component_registry.h"
#include "dragnn/core/component_registry.h"
#include "dragnn/core/interfaces/component.h"
#include "dragnn/core/interfaces/component.h"
#include "dragnn/core/interfaces/transition_state.h"
#include "dragnn/core/interfaces/transition_state.h"
#include "dragnn/core/util/label.h"
#include "dragnn/protos/data.pb.h"
#include "dragnn/protos/data.pb.h"
#include "syntaxnet/base.h"
#include "syntaxnet/base.h"
...
@@ -90,7 +91,8 @@ class StatelessComponent : public Component {
...
@@ -90,7 +91,8 @@ class StatelessComponent : public Component {
void
AdvanceFromOracle
()
override
{
void
AdvanceFromOracle
()
override
{
LOG
(
FATAL
)
<<
"["
<<
name_
<<
"] AdvanceFromOracle not supported"
;
LOG
(
FATAL
)
<<
"["
<<
name_
<<
"] AdvanceFromOracle not supported"
;
}
}
std
::
vector
<
std
::
vector
<
int
>>
GetOracleLabels
()
const
override
{
std
::
vector
<
std
::
vector
<
std
::
vector
<
Label
>>>
GetOracleLabels
()
const
override
{
LOG
(
FATAL
)
<<
"["
<<
name_
<<
"] Method not supported"
;
LOG
(
FATAL
)
<<
"["
<<
name_
<<
"] Method not supported"
;
}
}
int
GetFixedFeatures
(
std
::
function
<
int32
*
(
int
)
>
allocate_indices
,
int
GetFixedFeatures
(
std
::
function
<
int32
*
(
int
)
>
allocate_indices
,
...
@@ -108,7 +110,15 @@ class StatelessComponent : public Component {
...
@@ -108,7 +110,15 @@ class StatelessComponent : public Component {
float
*
embedding_output
)
override
{
float
*
embedding_output
)
override
{
LOG
(
FATAL
)
<<
"["
<<
name_
<<
"] Method not supported"
;
LOG
(
FATAL
)
<<
"["
<<
name_
<<
"] Method not supported"
;
}
}
void
BulkEmbedDenseFixedFeatures
(
const
vector
<
const
float
*>
&
per_channel_embeddings
,
float
*
embedding_output
,
int
embedding_output_size
,
int32
*
offset_array_output
,
int
offset_array_size
)
override
{
LOG
(
FATAL
)
<<
"["
<<
name_
<<
"] Method not supported"
;
}
int
BulkDenseFeatureSize
()
const
override
{
LOG
(
FATAL
)
<<
"Method not supported"
;
}
std
::
vector
<
LinkFeatures
>
GetRawLinkFeatures
(
int
channel_id
)
const
override
{
std
::
vector
<
LinkFeatures
>
GetRawLinkFeatures
(
int
channel_id
)
const
override
{
LOG
(
FATAL
)
<<
"["
<<
name_
<<
"] Method not supported"
;
LOG
(
FATAL
)
<<
"["
<<
name_
<<
"] Method not supported"
;
}
}
...
@@ -118,9 +128,9 @@ class StatelessComponent : public Component {
...
@@ -118,9 +128,9 @@ class StatelessComponent : public Component {
}
}
private:
private:
string
name_
;
// component name
string
name_
;
// component name
int
batch_size_
=
1
;
// number of sentences in current batch
int
batch_size_
=
1
;
// number of sentences in current batch
int
beam_size_
=
1
;
// maximum beam size
int
beam_size_
=
1
;
// maximum beam size
// Parent states passed to InitializeData(), and passed along in GetBeam().
// Parent states passed to InitializeData(), and passed along in GetBeam().
std
::
vector
<
std
::
vector
<
const
TransitionState
*>>
parent_states_
;
std
::
vector
<
std
::
vector
<
const
TransitionState
*>>
parent_states_
;
...
...
research/syntaxnet/dragnn/components/syntaxnet/BUILD
View file @
a4bb31d0
...
@@ -16,18 +16,20 @@ cc_library(
...
@@ -16,18 +16,20 @@ cc_library(
"//dragnn/core:input_batch_cache"
,
"//dragnn/core:input_batch_cache"
,
"//dragnn/core/interfaces:component"
,
"//dragnn/core/interfaces:component"
,
"//dragnn/core/interfaces:transition_state"
,
"//dragnn/core/interfaces:transition_state"
,
"//dragnn/core/util:label"
,
"//dragnn/io:sentence_input_batch"
,
"//dragnn/io:sentence_input_batch"
,
"//dragnn/io:syntaxnet_sentence"
,
"//dragnn/io:syntaxnet_sentence"
,
"//dragnn/protos:data_proto"
,
"//dragnn/protos:data_proto
_cc
"
,
"//dragnn/protos:spec_proto"
,
"//dragnn/protos:spec_proto
_cc
"
,
"//dragnn/protos:trace_proto"
,
"//dragnn/protos:trace_proto
_cc
"
,
"//syntaxnet:base"
,
"//syntaxnet:base"
,
"//syntaxnet:parser_transitions"
,
"//syntaxnet:parser_transitions"
,
"//syntaxnet:registry"
,
"//syntaxnet:registry"
,
"//syntaxnet:sparse_proto"
,
"//syntaxnet:sparse_proto
_cc
"
,
"//syntaxnet:task_context"
,
"//syntaxnet:task_context"
,
"//syntaxnet:task_spec_proto"
,
"//syntaxnet:task_spec_proto
_cc
"
,
"//syntaxnet:utils"
,
"//syntaxnet:utils"
,
"//util/utf8:unicodetext"
,
],
],
alwayslink
=
1
,
alwayslink
=
1
,
)
)
...
@@ -37,7 +39,7 @@ cc_library(
...
@@ -37,7 +39,7 @@ cc_library(
srcs
=
[
"syntaxnet_link_feature_extractor.cc"
],
srcs
=
[
"syntaxnet_link_feature_extractor.cc"
],
hdrs
=
[
"syntaxnet_link_feature_extractor.h"
],
hdrs
=
[
"syntaxnet_link_feature_extractor.h"
],
deps
=
[
deps
=
[
"//dragnn/protos:spec_proto"
,
"//dragnn/protos:spec_proto
_cc
"
,
"//syntaxnet:base"
,
"//syntaxnet:base"
,
"//syntaxnet:embedding_feature_extractor"
,
"//syntaxnet:embedding_feature_extractor"
,
"//syntaxnet:parser_transitions"
,
"//syntaxnet:parser_transitions"
,
...
@@ -53,7 +55,7 @@ cc_library(
...
@@ -53,7 +55,7 @@ cc_library(
"//dragnn/core/interfaces:cloneable_transition_state"
,
"//dragnn/core/interfaces:cloneable_transition_state"
,
"//dragnn/core/interfaces:transition_state"
,
"//dragnn/core/interfaces:transition_state"
,
"//dragnn/io:syntaxnet_sentence"
,
"//dragnn/io:syntaxnet_sentence"
,
"//dragnn/protos:trace_proto"
,
"//dragnn/protos:trace_proto
_cc
"
,
"//syntaxnet:base"
,
"//syntaxnet:base"
,
"//syntaxnet:parser_transitions"
,
"//syntaxnet:parser_transitions"
,
],
],
...
@@ -77,7 +79,7 @@ cc_test(
...
@@ -77,7 +79,7 @@ cc_test(
"//dragnn/core/test:mock_transition_state"
,
"//dragnn/core/test:mock_transition_state"
,
"//dragnn/io:sentence_input_batch"
,
"//dragnn/io:sentence_input_batch"
,
"//syntaxnet:base"
,
"//syntaxnet:base"
,
"//syntaxnet:sentence_proto"
,
"//syntaxnet:sentence_proto
_cc
"
,
"//syntaxnet:test_main"
,
"//syntaxnet:test_main"
,
],
],
)
)
...
@@ -88,7 +90,7 @@ cc_test(
...
@@ -88,7 +90,7 @@ cc_test(
deps
=
[
deps
=
[
":syntaxnet_link_feature_extractor"
,
":syntaxnet_link_feature_extractor"
,
"//dragnn/core/test:generic"
,
"//dragnn/core/test:generic"
,
"//dragnn/protos:spec_proto"
,
"//dragnn/protos:spec_proto
_cc
"
,
"//syntaxnet:task_context"
,
"//syntaxnet:task_context"
,
"//syntaxnet:test_main"
,
"//syntaxnet:test_main"
,
],
],
...
@@ -105,9 +107,9 @@ cc_test(
...
@@ -105,9 +107,9 @@ cc_test(
"//dragnn/core/test:generic"
,
"//dragnn/core/test:generic"
,
"//dragnn/core/test:mock_transition_state"
,
"//dragnn/core/test:mock_transition_state"
,
"//dragnn/io:sentence_input_batch"
,
"//dragnn/io:sentence_input_batch"
,
"//dragnn/protos:spec_proto"
,
"//dragnn/protos:spec_proto
_cc
"
,
"//syntaxnet:base"
,
"//syntaxnet:base"
,
"//syntaxnet:sentence_proto"
,
"//syntaxnet:sentence_proto
_cc
"
,
"//syntaxnet:test_main"
,
"//syntaxnet:test_main"
,
],
],
)
)
research/syntaxnet/dragnn/components/syntaxnet/syntaxnet_component.cc
View file @
a4bb31d0
...
@@ -22,6 +22,7 @@
...
@@ -22,6 +22,7 @@
#include "dragnn/core/input_batch_cache.h"
#include "dragnn/core/input_batch_cache.h"
#include "dragnn/core/interfaces/component.h"
#include "dragnn/core/interfaces/component.h"
#include "dragnn/core/interfaces/transition_state.h"
#include "dragnn/core/interfaces/transition_state.h"
#include "dragnn/core/util/label.h"
#include "dragnn/io/sentence_input_batch.h"
#include "dragnn/io/sentence_input_batch.h"
#include "dragnn/io/syntaxnet_sentence.h"
#include "dragnn/io/syntaxnet_sentence.h"
#include "syntaxnet/parser_state.h"
#include "syntaxnet/parser_state.h"
...
@@ -29,13 +30,12 @@
...
@@ -29,13 +30,12 @@
#include "syntaxnet/task_spec.pb.h"
#include "syntaxnet/task_spec.pb.h"
#include "syntaxnet/utils.h"
#include "syntaxnet/utils.h"
#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/logging.h"
#include "util/utf8/unicodetext.h"
namespace
syntaxnet
{
namespace
syntaxnet
{
namespace
dragnn
{
namespace
dragnn
{
using
tensorflow
::
strings
::
StrCat
;
namespace
{
namespace
{
// Returns a new step in a trace based on a ComponentSpec.
// Returns a new step in a trace based on a ComponentSpec.
...
@@ -103,7 +103,7 @@ void SyntaxNetComponent::InitializeComponent(const ComponentSpec &spec) {
...
@@ -103,7 +103,7 @@ void SyntaxNetComponent::InitializeComponent(const ComponentSpec &spec) {
names
.
push_back
(
channel
.
name
());
names
.
push_back
(
channel
.
name
());
fml
.
push_back
(
channel
.
fml
());
fml
.
push_back
(
channel
.
fml
());
predicate_maps
.
push_back
(
channel
.
predicate_map
());
predicate_maps
.
push_back
(
channel
.
predicate_map
());
dims
.
push_back
(
StrCat
(
channel
.
embedding_dim
()));
dims
.
push_back
(
tensorflow
::
strings
::
StrCat
(
channel
.
embedding_dim
()));
}
}
...
@@ -125,7 +125,7 @@ void SyntaxNetComponent::InitializeComponent(const ComponentSpec &spec) {
...
@@ -125,7 +125,7 @@ void SyntaxNetComponent::InitializeComponent(const ComponentSpec &spec) {
for
(
const
LinkedFeatureChannel
&
channel
:
spec
.
linked_feature
())
{
for
(
const
LinkedFeatureChannel
&
channel
:
spec
.
linked_feature
())
{
names
.
push_back
(
channel
.
name
());
names
.
push_back
(
channel
.
name
());
fml
.
push_back
(
channel
.
fml
());
fml
.
push_back
(
channel
.
fml
());
dims
.
push_back
(
StrCat
(
channel
.
embedding_dim
()));
dims
.
push_back
(
tensorflow
::
strings
::
StrCat
(
channel
.
embedding_dim
()));
source_components
.
push_back
(
channel
.
source_component
());
source_components
.
push_back
(
channel
.
source_component
());
source_layers
.
push_back
(
channel
.
source_layer
());
source_layers
.
push_back
(
channel
.
source_layer
());
source_translators
.
push_back
(
channel
.
source_translator
());
source_translators
.
push_back
(
channel
.
source_translator
());
...
@@ -332,6 +332,22 @@ std::function<int(int, int, int)> SyntaxNetComponent::GetStepLookupFunction(
...
@@ -332,6 +332,22 @@ std::function<int(int, int, int)> SyntaxNetComponent::GetStepLookupFunction(
return
-
1
;
return
-
1
;
}
}
};
};
}
else
if
(
method
==
"reverse-char"
)
{
// Reverses the character-level index.
return
[
this
](
int
batch_index
,
int
beam_index
,
int
value
)
{
SyntaxNetTransitionState
*
state
=
batch_
.
at
(
batch_index
)
->
beam_state
(
beam_index
);
const
auto
*
sentence
=
state
->
sentence
()
->
sentence
();
const
string
&
text
=
sentence
->
text
();
const
int
start_byte
=
sentence
->
token
(
0
).
start
();
const
int
end_byte
=
sentence
->
token
(
sentence
->
token_size
()
-
1
).
end
();
UnicodeText
unicode
;
unicode
.
PointToUTF8
(
text
.
data
()
+
start_byte
,
end_byte
-
start_byte
+
1
);
const
int
num_chars
=
distance
(
unicode
.
begin
(),
unicode
.
end
());
const
int
result
=
num_chars
-
value
-
1
;
if
(
result
>=
0
&&
result
<
num_chars
)
return
result
;
return
-
1
;
};
}
else
{
}
else
{
LOG
(
FATAL
)
<<
"Unable to find step lookup function "
<<
method
;
LOG
(
FATAL
)
<<
"Unable to find step lookup function "
<<
method
;
}
}
...
@@ -418,12 +434,12 @@ int SyntaxNetComponent::GetFixedFeatures(
...
@@ -418,12 +434,12 @@ int SyntaxNetComponent::GetFixedFeatures(
const
bool
has_weights
=
f
.
weight_size
()
!=
0
;
const
bool
has_weights
=
f
.
weight_size
()
!=
0
;
for
(
int
i
=
0
;
i
<
f
.
description_size
();
++
i
)
{
for
(
int
i
=
0
;
i
<
f
.
description_size
();
++
i
)
{
if
(
has_weights
)
{
if
(
has_weights
)
{
fixed_features
.
add_value_name
(
StrCat
(
"id: "
,
f
.
id
(
i
),
fixed_features
.
add_value_name
(
tensorflow
::
strings
::
StrCat
(
" name: "
,
f
.
description
(
i
),
"id: "
,
f
.
id
(
i
),
" name: "
,
f
.
description
(
i
),
" weight: "
,
f
.
weight
(
i
)));
" weight: "
,
f
.
weight
(
i
)));
}
else
{
}
else
{
fixed_features
.
add_value_name
(
fixed_features
.
add_value_name
(
tensorflow
::
strings
::
StrCat
(
StrCat
(
"id: "
,
f
.
id
(
i
),
" name: "
,
f
.
description
(
i
)));
"id: "
,
f
.
id
(
i
),
" name: "
,
f
.
description
(
i
)));
}
}
}
}
fixed_features
.
set_feature_name
(
""
);
fixed_features
.
set_feature_name
(
""
);
...
@@ -615,16 +631,19 @@ std::vector<LinkFeatures> SyntaxNetComponent::GetRawLinkFeatures(
...
@@ -615,16 +631,19 @@ std::vector<LinkFeatures> SyntaxNetComponent::GetRawLinkFeatures(
return
features
;
return
features
;
}
}
std
::
vector
<
std
::
vector
<
int
>>
SyntaxNetComponent
::
GetOracleLabels
()
const
{
std
::
vector
<
std
::
vector
<
std
::
vector
<
Label
>>>
std
::
vector
<
std
::
vector
<
int
>>
oracle_labels
;
SyntaxNetComponent
::
GetOracleLabels
()
const
{
for
(
const
auto
&
beam
:
batch_
)
{
std
::
vector
<
std
::
vector
<
std
::
vector
<
Label
>>>
oracle_labels
(
batch_
.
size
());
oracle_labels
.
emplace_back
();
for
(
int
batch_idx
=
0
;
batch_idx
<
batch_
.
size
();
++
batch_idx
)
{
const
auto
&
beam
=
batch_
[
batch_idx
];
std
::
vector
<
std
::
vector
<
Label
>>
&
output_beam
=
oracle_labels
[
batch_idx
];
for
(
int
beam_idx
=
0
;
beam_idx
<
beam
->
size
();
++
beam_idx
)
{
for
(
int
beam_idx
=
0
;
beam_idx
<
beam
->
size
();
++
beam_idx
)
{
// Get the raw link features from the linked feature extractor.
// Get the raw link features from the linked feature extractor.
auto
state
=
beam
->
beam_state
(
beam_idx
);
auto
state
=
beam
->
beam_state
(
beam_idx
);
// Arbitrarily choose the first vector element.
// Arbitrarily choose the first vector element.
oracle_labels
.
back
().
push_back
(
GetOracleVector
(
state
).
front
());
output_beam
.
emplace_back
();
output_beam
.
back
().
emplace_back
(
GetOracleVector
(
state
).
front
());
}
}
}
}
return
oracle_labels
;
return
oracle_labels
;
...
...
research/syntaxnet/dragnn/components/syntaxnet/syntaxnet_component.h
View file @
a4bb31d0
...
@@ -25,6 +25,7 @@
...
@@ -25,6 +25,7 @@
#include "dragnn/core/input_batch_cache.h"
#include "dragnn/core/input_batch_cache.h"
#include "dragnn/core/interfaces/component.h"
#include "dragnn/core/interfaces/component.h"
#include "dragnn/core/interfaces/transition_state.h"
#include "dragnn/core/interfaces/transition_state.h"
#include "dragnn/core/util/label.h"
#include "dragnn/protos/data.pb.h"
#include "dragnn/protos/data.pb.h"
#include "dragnn/protos/spec.pb.h"
#include "dragnn/protos/spec.pb.h"
#include "dragnn/protos/trace.pb.h"
#include "dragnn/protos/trace.pb.h"
...
@@ -113,13 +114,24 @@ class SyntaxNetComponent : public Component {
...
@@ -113,13 +114,24 @@ class SyntaxNetComponent : public Component {
LOG
(
FATAL
)
<<
"Method not supported"
;
LOG
(
FATAL
)
<<
"Method not supported"
;
}
}
void
BulkEmbedDenseFixedFeatures
(
const
vector
<
const
float
*>
&
per_channel_embeddings
,
float
*
embedding_output
,
int
embedding_output_size
,
int32
*
offset_array_output
,
int
offset_array_size
)
override
{
LOG
(
FATAL
)
<<
"Method not supported"
;
}
int
BulkDenseFeatureSize
()
const
override
{
LOG
(
FATAL
)
<<
"Method not supported"
;
}
// Extracts and returns the vector of LinkFeatures for the specified
// Extracts and returns the vector of LinkFeatures for the specified
// channel. Note: these are NOT translated.
// channel. Note: these are NOT translated.
std
::
vector
<
LinkFeatures
>
GetRawLinkFeatures
(
int
channel_id
)
const
override
;
std
::
vector
<
LinkFeatures
>
GetRawLinkFeatures
(
int
channel_id
)
const
override
;
// Returns a vector of oracle labels for each element in the beam and
// Returns a vector of oracle labels for each element in the beam and
// batch.
// batch.
std
::
vector
<
std
::
vector
<
int
>>
GetOracleLabels
()
const
override
;
std
::
vector
<
std
::
vector
<
std
::
vector
<
Label
>
>>
GetOracleLabels
()
const
override
;
// Annotate the underlying data object with the results of this Component's
// Annotate the underlying data object with the results of this Component's
// calculation.
// calculation.
...
...
research/syntaxnet/dragnn/components/syntaxnet/syntaxnet_component_test.cc
View file @
a4bb31d0
...
@@ -40,6 +40,7 @@ namespace dragnn {
...
@@ -40,6 +40,7 @@ namespace dragnn {
namespace
{
namespace
{
const
char
kSentence0
[]
=
R"(
const
char
kSentence0
[]
=
R"(
text: "Sentence 0."
token {
token {
word: "Sentence" start: 0 end: 7 tag: "NN" category: "NOUN" label: "ROOT"
word: "Sentence" start: 0 end: 7 tag: "NN" category: "NOUN" label: "ROOT"
break_level: NO_BREAK
break_level: NO_BREAK
...
@@ -55,6 +56,7 @@ token {
...
@@ -55,6 +56,7 @@ token {
)"
;
)"
;
const
char
kSentence1
[]
=
R"(
const
char
kSentence1
[]
=
R"(
text: "Sentence 1."
token {
token {
word: "Sentence" start: 0 end: 7 tag: "NN" category: "NOUN" label: "ROOT"
word: "Sentence" start: 0 end: 7 tag: "NN" category: "NOUN" label: "ROOT"
break_level: NO_BREAK
break_level: NO_BREAK
...
@@ -70,6 +72,7 @@ token {
...
@@ -70,6 +72,7 @@ token {
)"
;
)"
;
const
char
kLongSentence
[]
=
R"(
const
char
kLongSentence
[]
=
R"(
text: "Sentence 123."
token {
token {
word: "Sentence" start: 0 end: 7 tag: "NN" category: "NOUN" label: "ROOT"
word: "Sentence" start: 0 end: 7 tag: "NN" category: "NOUN" label: "ROOT"
break_level: NO_BREAK
break_level: NO_BREAK
...
@@ -1310,5 +1313,30 @@ TEST_F(SyntaxNetComponentTest, BulkEmbedFixedFeaturesIsNotSupported) {
...
@@ -1310,5 +1313,30 @@ TEST_F(SyntaxNetComponentTest, BulkEmbedFixedFeaturesIsNotSupported) {
"Method not supported"
);
"Method not supported"
);
}
}
TEST_F
(
SyntaxNetComponentTest
,
GetStepLookupFunction
)
{
Sentence
sentence_0
;
TextFormat
::
ParseFromString
(
kSentence0
,
&
sentence_0
);
string
sentence_0_str
;
sentence_0
.
SerializeToString
(
&
sentence_0_str
);
constexpr
int
kBeamSize
=
1
;
auto
test_parser
=
CreateParserWithBeamSize
(
kBeamSize
,
{},
{
sentence_0_str
});
ASSERT_TRUE
(
test_parser
->
IsReady
());
const
auto
reverse_token_lookup
=
test_parser
->
GetStepLookupFunction
(
"reverse-token"
);
const
int
kNumTokens
=
sentence_0
.
token_size
();
for
(
int
i
=
0
;
i
<
kNumTokens
;
++
i
)
{
EXPECT_EQ
(
i
,
reverse_token_lookup
(
0
,
0
,
kNumTokens
-
i
-
1
));
}
const
auto
reverse_char_lookup
=
test_parser
->
GetStepLookupFunction
(
"reverse-char"
);
const
int
kNumChars
=
sentence_0
.
text
().
size
();
// assumes ASCII
for
(
int
i
=
0
;
i
<
kNumChars
;
++
i
)
{
EXPECT_EQ
(
i
,
reverse_char_lookup
(
0
,
0
,
kNumChars
-
i
-
1
));
}
}
}
// namespace dragnn
}
// namespace dragnn
}
// namespace syntaxnet
}
// namespace syntaxnet
research/syntaxnet/dragnn/conll2017/BUILD
View file @
a4bb31d0
...
@@ -2,8 +2,9 @@ py_binary(
...
@@ -2,8 +2,9 @@ py_binary(
name
=
"make_parser_spec"
,
name
=
"make_parser_spec"
,
srcs
=
[
"make_parser_spec.py"
],
srcs
=
[
"make_parser_spec.py"
],
deps
=
[
deps
=
[
"//dragnn/protos:spec_p
y_pb2
"
,
"//dragnn/protos:spec_p
b2_py
"
,
"//dragnn/python:spec_builder"
,
"//dragnn/python:spec_builder"
,
"@absl_py//absl/flags"
,
"@org_tensorflow//tensorflow:tensorflow_py"
,
"@org_tensorflow//tensorflow:tensorflow_py"
,
],
],
)
)
research/syntaxnet/dragnn/conll2017/make_parser_spec.py
View file @
a4bb31d0
...
@@ -14,6 +14,7 @@
...
@@ -14,6 +14,7 @@
# ==============================================================================
# ==============================================================================
"""Construct the spec for the CONLL2017 Parser baseline."""
"""Construct the spec for the CONLL2017 Parser baseline."""
from
absl
import
flags
import
tensorflow
as
tf
import
tensorflow
as
tf
from
tensorflow.python.platform
import
gfile
from
tensorflow.python.platform
import
gfile
...
@@ -21,7 +22,6 @@ from tensorflow.python.platform import gfile
...
@@ -21,7 +22,6 @@ from tensorflow.python.platform import gfile
from
dragnn.protos
import
spec_pb2
from
dragnn.protos
import
spec_pb2
from
dragnn.python
import
spec_builder
from
dragnn.python
import
spec_builder
flags
=
tf
.
app
.
flags
FLAGS
=
flags
.
FLAGS
FLAGS
=
flags
.
FLAGS
flags
.
DEFINE_string
(
'spec_file'
,
'parser_spec.textproto'
,
flags
.
DEFINE_string
(
'spec_file'
,
'parser_spec.textproto'
,
...
...
research/syntaxnet/dragnn/conll2017/sample/zh-segmenter.checkpoint.data-00000-of-00001
View file @
a4bb31d0
No preview for this file type
research/syntaxnet/dragnn/conll2017/sample/zh-segmenter.checkpoint.index
View file @
a4bb31d0
No preview for this file type
research/syntaxnet/dragnn/conll2017/sample/zh-segmenter.checkpoint.meta
View file @
a4bb31d0
No preview for this file type
research/syntaxnet/dragnn/core/BUILD
View file @
a4bb31d0
...
@@ -37,8 +37,9 @@ cc_library(
...
@@ -37,8 +37,9 @@ cc_library(
":input_batch_cache"
,
":input_batch_cache"
,
"//dragnn/components/util:bulk_feature_extractor"
,
"//dragnn/components/util:bulk_feature_extractor"
,
"//dragnn/core/interfaces:component"
,
"//dragnn/core/interfaces:component"
,
"//dragnn/protos:spec_proto"
,
"//dragnn/core/util:label"
,
"//dragnn/protos:trace_proto"
,
"//dragnn/protos:spec_proto_cc"
,
"//dragnn/protos:trace_proto_cc"
,
],
],
)
)
...
@@ -51,9 +52,10 @@ cc_library(
...
@@ -51,9 +52,10 @@ cc_library(
":index_translator"
,
":index_translator"
,
":input_batch_cache"
,
":input_batch_cache"
,
"//dragnn/components/util:bulk_feature_extractor"
,
"//dragnn/components/util:bulk_feature_extractor"
,
"//dragnn/protos:data_proto"
,
"//dragnn/core/util:label"
,
"//dragnn/protos:spec_proto"
,
"//dragnn/protos:data_proto_cc"
,
"//dragnn/protos:trace_proto"
,
"//dragnn/protos:spec_proto_cc"
,
"//dragnn/protos:trace_proto_cc"
,
"//syntaxnet:base"
,
"//syntaxnet:base"
,
"//syntaxnet:registry"
,
"//syntaxnet:registry"
,
],
],
...
@@ -67,7 +69,7 @@ cc_library(
...
@@ -67,7 +69,7 @@ cc_library(
":component_registry"
,
":component_registry"
,
":compute_session"
,
":compute_session"
,
":compute_session_impl"
,
":compute_session_impl"
,
"//dragnn/protos:spec_proto"
,
"//dragnn/protos:spec_proto
_cc
"
,
"//syntaxnet:base"
,
"//syntaxnet:base"
,
],
],
)
)
...
@@ -125,10 +127,13 @@ cc_test(
...
@@ -125,10 +127,13 @@ cc_test(
"//dragnn/components/util:bulk_feature_extractor"
,
"//dragnn/components/util:bulk_feature_extractor"
,
"//dragnn/core/interfaces:component"
,
"//dragnn/core/interfaces:component"
,
"//dragnn/core/interfaces:input_batch"
,
"//dragnn/core/interfaces:input_batch"
,
"//dragnn/core/test:fake_component_base"
,
"//dragnn/core/test:generic"
,
"//dragnn/core/test:generic"
,
"//dragnn/core/test:mock_component"
,
"//dragnn/core/test:mock_component"
,
"//dragnn/core/test:mock_transition_state"
,
"//dragnn/core/test:mock_transition_state"
,
"//dragnn/core/util:label"
,
"//syntaxnet:base"
,
"//syntaxnet:base"
,
"@org_tensorflow//tensorflow/core:test"
,
],
],
)
)
...
@@ -182,14 +187,24 @@ cc_test(
...
@@ -182,14 +187,24 @@ cc_test(
# Tensorflow op kernel BUILD rules.
# Tensorflow op kernel BUILD rules.
load
(
load
(
"
//dragnn
:tensorflow
_ops
.bzl"
,
"
@org_tensorflow//tensorflow
:tensorflow.bzl"
,
"tf_gen_op_libs"
,
"tf_gen_op_libs"
,
"tf_gen_op_wrapper_py"
,
"tf_gen_op_wrapper_py"
,
"tf_kernel_library"
,
"tf_kernel_library"
,
)
)
cc_library
(
name
=
"shape_helpers"
,
hdrs
=
[
"ops/shape_helpers.h"
],
deps
=
[
"//syntaxnet:shape_helpers"
,
"@org_tensorflow//tensorflow/core:framework_headers_lib"
,
],
)
tf_gen_op_libs
(
tf_gen_op_libs
(
op_lib_names
=
[
"dragnn_ops"
],
op_lib_names
=
[
"dragnn_ops"
],
deps
=
[
":shape_helpers"
],
)
)
tf_gen_op_wrapper_py
(
tf_gen_op_wrapper_py
(
...
@@ -199,6 +214,7 @@ tf_gen_op_wrapper_py(
...
@@ -199,6 +214,7 @@ tf_gen_op_wrapper_py(
tf_gen_op_libs
(
tf_gen_op_libs
(
op_lib_names
=
[
"dragnn_bulk_ops"
],
op_lib_names
=
[
"dragnn_bulk_ops"
],
deps
=
[
":shape_helpers"
],
)
)
tf_gen_op_wrapper_py
(
tf_gen_op_wrapper_py
(
...
@@ -231,8 +247,10 @@ cc_library(
...
@@ -231,8 +247,10 @@ cc_library(
":compute_session_op"
,
":compute_session_op"
,
":compute_session_pool"
,
":compute_session_pool"
,
":resource_container"
,
":resource_container"
,
"//dragnn/protos:data_proto"
,
":shape_helpers"
,
"//dragnn/protos:spec_proto"
,
"//dragnn/core/util:label"
,
"//dragnn/protos:data_proto_cc"
,
"//dragnn/protos:spec_proto_cc"
,
"//syntaxnet:base"
,
"//syntaxnet:base"
,
"@org_tensorflow//third_party/eigen3"
,
"@org_tensorflow//third_party/eigen3"
,
],
],
...
@@ -248,6 +266,8 @@ cc_library(
...
@@ -248,6 +266,8 @@ cc_library(
deps
=
[
deps
=
[
":compute_session_op"
,
":compute_session_op"
,
":resource_container"
,
":resource_container"
,
":shape_helpers"
,
"//dragnn/core/util:label"
,
"//syntaxnet:base"
,
"//syntaxnet:base"
,
"@org_tensorflow//third_party/eigen3"
,
"@org_tensorflow//third_party/eigen3"
,
],
],
...
@@ -269,8 +289,10 @@ tf_kernel_library(
...
@@ -269,8 +289,10 @@ tf_kernel_library(
":compute_session_op"
,
":compute_session_op"
,
":compute_session_pool"
,
":compute_session_pool"
,
":resource_container"
,
":resource_container"
,
"//dragnn/protos:data_proto"
,
":shape_helpers"
,
"//dragnn/protos:spec_proto"
,
"//dragnn/core/util:label"
,
"//dragnn/protos:data_proto_cc"
,
"//dragnn/protos:spec_proto_cc"
,
"//syntaxnet:base"
,
"//syntaxnet:base"
,
"@org_tensorflow//third_party/eigen3"
,
"@org_tensorflow//third_party/eigen3"
,
],
],
...
@@ -289,8 +311,10 @@ tf_kernel_library(
...
@@ -289,8 +311,10 @@ tf_kernel_library(
":compute_session_op"
,
":compute_session_op"
,
":compute_session_pool"
,
":compute_session_pool"
,
":resource_container"
,
":resource_container"
,
":shape_helpers"
,
"//dragnn/components/util:bulk_feature_extractor"
,
"//dragnn/components/util:bulk_feature_extractor"
,
"//dragnn/protos:spec_proto"
,
"//dragnn/core/util:label"
,
"//dragnn/protos:spec_proto_cc"
,
"//syntaxnet:base"
,
"//syntaxnet:base"
,
"@org_tensorflow//tensorflow/core:protos_all_cc"
,
"@org_tensorflow//tensorflow/core:protos_all_cc"
,
"@org_tensorflow//third_party/eigen3"
,
"@org_tensorflow//third_party/eigen3"
,
...
@@ -309,6 +333,7 @@ cc_test(
...
@@ -309,6 +333,7 @@ cc_test(
":resource_container"
,
":resource_container"
,
"//dragnn/core/test:generic"
,
"//dragnn/core/test:generic"
,
"//dragnn/core/test:mock_compute_session"
,
"//dragnn/core/test:mock_compute_session"
,
"//dragnn/core/util:label"
,
"//syntaxnet:base"
,
"//syntaxnet:base"
,
"//syntaxnet:test_main"
,
"//syntaxnet:test_main"
,
"@org_tensorflow//tensorflow/core:protos_all_cc"
,
"@org_tensorflow//tensorflow/core:protos_all_cc"
,
...
@@ -327,6 +352,7 @@ cc_test(
...
@@ -327,6 +352,7 @@ cc_test(
":resource_container"
,
":resource_container"
,
"//dragnn/components/util:bulk_feature_extractor"
,
"//dragnn/components/util:bulk_feature_extractor"
,
"//dragnn/core/test:mock_compute_session"
,
"//dragnn/core/test:mock_compute_session"
,
"//dragnn/core/util:label"
,
"//syntaxnet:base"
,
"//syntaxnet:base"
,
"//syntaxnet:test_main"
,
"//syntaxnet:test_main"
,
"@org_tensorflow//tensorflow/core/kernels:ops_testutil"
,
"@org_tensorflow//tensorflow/core/kernels:ops_testutil"
,
...
...
research/syntaxnet/dragnn/core/compute_session.h
View file @
a4bb31d0
...
@@ -22,6 +22,7 @@
...
@@ -22,6 +22,7 @@
#include "dragnn/core/index_translator.h"
#include "dragnn/core/index_translator.h"
#include "dragnn/core/input_batch_cache.h"
#include "dragnn/core/input_batch_cache.h"
#include "dragnn/core/interfaces/component.h"
#include "dragnn/core/interfaces/component.h"
#include "dragnn/core/util/label.h"
#include "dragnn/protos/spec.pb.h"
#include "dragnn/protos/spec.pb.h"
#include "dragnn/protos/trace.pb.h"
#include "dragnn/protos/trace.pb.h"
...
@@ -102,7 +103,7 @@ class ComputeSession {
...
@@ -102,7 +103,7 @@ class ComputeSession {
const
string
&
component_name
,
int
channel_id
)
=
0
;
const
string
&
component_name
,
int
channel_id
)
=
0
;
// Get the oracle labels for the given component.
// Get the oracle labels for the given component.
virtual
std
::
vector
<
std
::
vector
<
int
>>
EmitOracleLabels
(
virtual
std
::
vector
<
std
::
vector
<
std
::
vector
<
Label
>
>>
EmitOracleLabels
(
const
string
&
component_name
)
=
0
;
const
string
&
component_name
)
=
0
;
// Returns true if the given component is terminal.
// Returns true if the given component is terminal.
...
@@ -126,6 +127,9 @@ class ComputeSession {
...
@@ -126,6 +127,9 @@ class ComputeSession {
// bypassing de-serialization.
// bypassing de-serialization.
virtual
void
SetInputBatchCache
(
std
::
unique_ptr
<
InputBatchCache
>
batch
)
=
0
;
virtual
void
SetInputBatchCache
(
std
::
unique_ptr
<
InputBatchCache
>
batch
)
=
0
;
// Returns the current InputBatchCache, or null if there is none.
virtual
InputBatchCache
*
GetInputBatchCache
()
=
0
;
// Resets all components owned by this ComputeSession.
// Resets all components owned by this ComputeSession.
virtual
void
ResetSession
()
=
0
;
virtual
void
ResetSession
()
=
0
;
...
...
research/syntaxnet/dragnn/core/compute_session_impl.cc
View file @
a4bb31d0
...
@@ -18,6 +18,7 @@
...
@@ -18,6 +18,7 @@
#include <algorithm>
#include <algorithm>
#include <utility>
#include <utility>
#include "dragnn/core/util/label.h"
#include "dragnn/protos/data.pb.h"
#include "dragnn/protos/data.pb.h"
#include "dragnn/protos/spec.pb.h"
#include "dragnn/protos/spec.pb.h"
#include "dragnn/protos/trace.pb.h"
#include "dragnn/protos/trace.pb.h"
...
@@ -123,8 +124,12 @@ void ComputeSessionImpl::InitializeComponentData(const string &component_name,
...
@@ -123,8 +124,12 @@ void ComputeSessionImpl::InitializeComponentData(const string &component_name,
VLOG
(
1
)
<<
"Source result found. Using prior initialization vector for "
VLOG
(
1
)
<<
"Source result found. Using prior initialization vector for "
<<
component_name
;
<<
component_name
;
auto
source
=
source_result
->
second
;
auto
source
=
source_result
->
second
;
CHECK
(
source
->
IsTerminal
())
<<
"Source is not terminal for component '"
CHECK
(
source
->
IsTerminal
())
<<
component_name
<<
"'. Exiting."
;
<<
"Source component '"
<<
source
->
Name
()
<<
"' for currently active component '"
<<
component_name
<<
"' is not terminal. "
<<
"Are you using bulk feature extraction with only linked features? "
<<
"If so, consider using the StatelessComponent instead. Exiting."
;
component
->
InitializeData
(
source
->
GetBeam
(),
max_beam_size
,
component
->
InitializeData
(
source
->
GetBeam
(),
max_beam_size
,
input_data_
.
get
());
input_data_
.
get
());
}
}
...
@@ -219,8 +224,8 @@ std::vector<LinkFeatures> ComputeSessionImpl::GetTranslatedLinkFeatures(
...
@@ -219,8 +224,8 @@ std::vector<LinkFeatures> ComputeSessionImpl::GetTranslatedLinkFeatures(
return
features
;
return
features
;
}
}
std
::
vector
<
std
::
vector
<
int
>>
ComputeSessionImpl
::
EmitOracle
Label
s
(
std
::
vector
<
std
::
vector
<
std
::
vector
<
Label
>>>
const
string
&
component_name
)
{
ComputeSessionImpl
::
EmitOracleLabels
(
const
string
&
component_name
)
{
return
GetReadiedComponent
(
component_name
)
->
GetOracleLabels
();
return
GetReadiedComponent
(
component_name
)
->
GetOracleLabels
();
}
}
...
@@ -303,6 +308,10 @@ void ComputeSessionImpl::SetInputBatchCache(
...
@@ -303,6 +308,10 @@ void ComputeSessionImpl::SetInputBatchCache(
input_data_
=
std
::
move
(
batch
);
input_data_
=
std
::
move
(
batch
);
}
}
InputBatchCache
*
ComputeSessionImpl
::
GetInputBatchCache
()
{
return
input_data_
.
get
();
}
void
ComputeSessionImpl
::
ResetSession
()
{
void
ComputeSessionImpl
::
ResetSession
()
{
// Reset all component states.
// Reset all component states.
for
(
auto
&
component_pair
:
components_
)
{
for
(
auto
&
component_pair
:
components_
)
{
...
...
Prev
1
2
3
4
5
…
22
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment