Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
ModelZoo
ResNet50_tensorflow
Commits
edea2b67
Commit
edea2b67
authored
May 11, 2018
by
Terry Koo
Browse files
Remove runtime because reasons.
parent
a4bb31d0
Changes
291
Hide whitespace changes
Inline
Side-by-side
Showing
11 changed files
with
0 additions
and
2277 deletions
+0
-2277
research/syntaxnet/dragnn/runtime/xla/xla_compilation.cc
research/syntaxnet/dragnn/runtime/xla/xla_compilation.cc
+0
-166
research/syntaxnet/dragnn/runtime/xla/xla_compilation.h
research/syntaxnet/dragnn/runtime/xla/xla_compilation.h
+0
-78
research/syntaxnet/dragnn/runtime/xla/xla_compilation_test.cc
...arch/syntaxnet/dragnn/runtime/xla/xla_compilation_test.cc
+0
-254
research/syntaxnet/dragnn/runtime/xla/xla_dynamic_component.cc
...rch/syntaxnet/dragnn/runtime/xla/xla_dynamic_component.cc
+0
-120
research/syntaxnet/dragnn/runtime/xla/xla_dynamic_component_base.cc
...yntaxnet/dragnn/runtime/xla/xla_dynamic_component_base.cc
+0
-407
research/syntaxnet/dragnn/runtime/xla/xla_dynamic_component_base.h
...syntaxnet/dragnn/runtime/xla/xla_dynamic_component_base.h
+0
-463
research/syntaxnet/dragnn/runtime/xla/xla_dynamic_component_test.cc
...yntaxnet/dragnn/runtime/xla/xla_dynamic_component_test.cc
+0
-389
research/syntaxnet/dragnn/runtime/xla/xla_extract_config.cc
research/syntaxnet/dragnn/runtime/xla/xla_extract_config.cc
+0
-69
research/syntaxnet/dragnn/runtime/xla/xla_extract_names_from_specs.cc
...taxnet/dragnn/runtime/xla/xla_extract_names_from_specs.cc
+0
-73
research/syntaxnet/dragnn/runtime/xla/xla_graph_utils.cc
research/syntaxnet/dragnn/runtime/xla/xla_graph_utils.cc
+0
-191
research/syntaxnet/dragnn/runtime/xla/xla_graph_utils.h
research/syntaxnet/dragnn/runtime/xla/xla_graph_utils.h
+0
-67
No files found.
Too many changes to show.
To preserve performance only
291 of 291+
files are displayed.
Plain diff
Email patch
research/syntaxnet/dragnn/runtime/xla/xla_compilation.cc
deleted
100644 → 0
View file @
a4bb31d0
// Copyright 2017 Google Inc. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
#include "dragnn/runtime/xla/xla_compilation.h"
#include <map>
#include <set>
#include <string>
#include <utility>
#include <vector>
#include "dragnn/protos/export.pb.h"
#include "dragnn/protos/spec.pb.h"
#include "dragnn/runtime/component.h"
#include "dragnn/runtime/trained_model.h"
#include "dragnn/runtime/xla/xla_cell_converter.h"
#include "dragnn/runtime/xla/xla_graph_utils.h"
#include "dragnn/runtime/xla/xla_spec_utils.h"
#include "syntaxnet/base.h"
#include "tensorflow/core/framework/graph.pb.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/io/path.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/env.h"
namespace
syntaxnet
{
namespace
dragnn
{
namespace
runtime
{
namespace
{
// Rewrites the Component subclass in the |component_spec| to its XLA-based
// counterpart. On error, returns non-OK and leaves the spec unmodified.
tensorflow::Status XlaCompileComponentSubclass(ComponentSpec *component_spec) {
  const string builder_name =
      GetNormalizedComponentBuilderName(*component_spec);
  if (builder_name == "DynamicComponent") {
    // By convention, the XLA-based version of "FooComponent" should be named
    // "XlaFooComponent".
    const string xla_builder_name =
        tensorflow::strings::StrCat("Xla", builder_name);
    component_spec->mutable_component_builder()->set_registered_name(
        xla_builder_name);
    return tensorflow::Status::OK();
  }
  return tensorflow::errors::Unimplemented(
      "No XLA-based version of Component subclass '", builder_name, "'");
}
// Appends the list of component specs in the |master_spec| whose names match
// |component_names| to |matching_components|. On error, returns non-OK.
//
// Errors:
// * InvalidArgument if the |master_spec| contains two components with the
//   same name.
// * InvalidArgument if any entry of |component_names| does not name a
//   component in the |master_spec|.
tensorflow::Status GetMatchingComponentSpecs(
    const std::set<string> &component_names, MasterSpec *master_spec,
    std::vector<ComponentSpec *> *matching_components) {
  // Index the components in the |master_spec| by name.
  std::map<string, ComponentSpec *> components;
  for (ComponentSpec &component_spec : *master_spec->mutable_component()) {
    if (!components.emplace(component_spec.name(), &component_spec).second) {
      return tensorflow::errors::InvalidArgument(
          "Duplicate component name: ", component_spec.name());
    }
  }

  // Append the components named in the |component_names|.  Use the iterator
  // returned by find() instead of a second operator[] lookup; operator[]
  // would also default-insert on a (logically impossible) miss.
  for (const string &component_name : component_names) {
    const auto it = components.find(component_name);
    if (it == components.end()) {
      return tensorflow::errors::InvalidArgument(
          "Unknown component name: ", component_name);
    }
    matching_components->push_back(it->second);
  }
  return tensorflow::Status::OK();
}
}
// namespace
// Converts the named components of a trained DRAGNN model into frozen
// GraphDefs suitable for XLA, and rewrites the MasterSpec to point at them.
// See the declaration in xla_compilation.h for the full contract.
//
// Note on ordering: all MasterSpec edits are applied before the (slow)
// trained-model load, so spec-level errors are reported quickly and before
// anything is written to |output_dir|.
tensorflow::Status XlaCompileCells(const string &saved_model_dir,
                                   const string &master_spec_path,
                                   const std::set<string> &component_names,
                                   const string &model_name,
                                   const string &output_dir) {
  // Parse the text-format MasterSpec from disk.
  MasterSpec master_spec;
  TF_RETURN_IF_ERROR(tensorflow::ReadTextProto(tensorflow::Env::Default(),
                                               master_spec_path,
                                               &master_spec));

  // Resolve |component_names| to mutable pointers into |master_spec|.
  std::vector<ComponentSpec *> components;
  TF_RETURN_IF_ERROR(
      GetMatchingComponentSpecs(component_names, &master_spec, &components));

  // Returns the path to the output frozen GraphDef file for the
  // |component_spec|.
  const auto get_frozen_graph_def_path =
      [&](const ComponentSpec &component_spec) {
        return tensorflow::io::JoinPath(
            output_dir,
            tensorflow::strings::StrCat(component_spec.name(),
                                        kFrozenGraphDefResourceFileSuffix));
      };

  // Perform some changes to the MasterSpec first, to catch issues before
  // loading the trained models, which is slow.
  for (ComponentSpec *component_spec : components) {
    // Add a resource for the frozen GraphDef file to each component. The file
    // will be created in a second pass, after loading the trained model.
    TF_RETURN_IF_ERROR(AddFrozenGraphDefResource(
        get_frozen_graph_def_path(*component_spec), component_spec));

    // Replace the Component subclass with an XLA-based version.
    TF_RETURN_IF_ERROR(XlaCompileComponentSubclass(component_spec));

    // Set embedding_dim=-1 for all channels: XLA takes raw feature IDs and
    // raw activation vectors instead of pre-embedded inputs.
    for (auto &fixed_channel : *component_spec->mutable_fixed_feature()) {
      fixed_channel.set_embedding_dim(-1);
    }
    for (auto &linked_channel : *component_spec->mutable_linked_feature()) {
      linked_channel.set_embedding_dim(-1);
    }
  }

  // Create output directory which contains the new master spec and
  // the frozen graphs.
  TF_RETURN_IF_ERROR(
      tensorflow::Env::Default()->RecursivelyCreateDir(output_dir));

  // Convert each component into a frozen GraphDef and write it. Also may
  // add a CompilationSpec.
  TrainedModel trained_model;
  TF_RETURN_IF_ERROR(trained_model.Reset(saved_model_dir));
  for (ComponentSpec *component_spec : components) {
    tensorflow::GraphDef frozen_graph_def;
    CellSubgraphSpec cell_subgraph_spec;
    TF_RETURN_IF_ERROR(XlaCellConverter::Convert(component_spec->name(),
                                                 trained_model,
                                                 &frozen_graph_def,
                                                 &cell_subgraph_spec));
    TF_RETURN_IF_ERROR(SaveFrozenGraphDef(
        get_frozen_graph_def_path(*component_spec), frozen_graph_def));

    // A non-empty |model_name| requests XLA AOT compilation metadata: attach
    // the model name and the extracted CellSubgraphSpec to the component.
    if (!model_name.empty()) {
      auto *compilation_spec = component_spec->MutableExtension(
          CompilationSpec::component_spec_extension);
      compilation_spec->set_model_name(model_name);
      *compilation_spec->mutable_cell_subgraph_spec() = cell_subgraph_spec;
    }
  }

  // Write the updated MasterSpec.
  TF_RETURN_IF_ERROR(tensorflow::WriteTextProto(
      tensorflow::Env::Default(),
      tensorflow::io::JoinPath(output_dir, "master-spec"), master_spec));
  return tensorflow::Status::OK();
}
}
// namespace runtime
}
// namespace dragnn
}
// namespace syntaxnet
research/syntaxnet/dragnn/runtime/xla/xla_compilation.h
deleted
100644 → 0
View file @
a4bb31d0
// Copyright 2017 Google Inc. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
// Utils for modifying pre-trained models to use XLA.
#ifndef DRAGNN_RUNTIME_XLA_XLA_COMPILATION_H_
#define DRAGNN_RUNTIME_XLA_XLA_COMPILATION_H_
#include <set>
#include <string>
#include "syntaxnet/base.h"
#include "tensorflow/core/lib/core/status.h"
namespace
syntaxnet
{
namespace
dragnn
{
namespace
runtime
{
// Modifies a DRAGNN model to use XLA.
//
// Loads a TF SavedModel from the |saved_model_dir| and a text-format MasterSpec
// from the |master_spec_path|. Converts each component in |component_names|
// into a frozen TF GraphDef (see xla_cell_converter.h) and writes the results
// to the |output_dir| as files "<output_dir>/<component_name>-frozen".
// Modifies the relevant ComponentSpecs in the MasterSpec to use XLA as
// described below, and writes it to "<output_dir>/master-spec".
//
// MasterSpec modifications:
// * Adds a resource to each ComponentSpec that points at the relevant
//   frozen GraphDef file in the |output_dir|.
// * Replaces the Component subclass specified in each ComponentSpec with the
//   XLA-based equivalent, which should be named "Xla<subclass_name>";
//   e.g., XlaDynamicComponent.
// * If |model_name| is non-empty, adds a CompilationSpec extension to each
//   ComponentSpec with |model_name| and its corresponding CellSubgraphSpec.
//   This is necessary for XLA AOT compilation.
// * Sets FixedFeatureChannel.embedding_dim to -1 in all channels, because
//   XLA takes feature IDs as input instead of fixed embedding sums.
// * Sets LinkedFeatureChannel.embedding_dim to -1 in all channels, because
//   XLA handles the linked embedding matrix multiplication (if any) and
//   always takes the original activation vector as input.
//
// On error, returns non-OK. Possible errors include:
// * Any file I/O or proto parsing error.
// * The MasterSpec has a duplicate component name.
// * One of the |component_names| does not match anything in the MasterSpec.
// * The MasterSpec already has XLA GraphDef resources.
// * One of the components is not supported by XLA.
// * Error raised by XlaCellConverter during conversion.
//
// Side note: This function has a file-path-based API so it can be easily
// wrapped in a stand-alone binary.
tensorflow::Status XlaCompileCells(const string &saved_model_dir,
                                   const string &master_spec_path,
                                   const std::set<string> &component_names,
                                   const string &model_name,
                                   const string &output_dir);
// TODO(googleuser): Add equivalent class for Myelinator.
}
// namespace runtime
}
// namespace dragnn
}
// namespace syntaxnet
#endif // DRAGNN_RUNTIME_XLA_XLA_COMPILATION_H_
research/syntaxnet/dragnn/runtime/xla/xla_compilation_test.cc
deleted
100644 → 0
View file @
a4bb31d0
// Copyright 2017 Google Inc. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
#include "dragnn/runtime/xla/xla_compilation.h"
#include <memory>
#include <string>
#include <utility>
#include "dragnn/core/test/generic.h"
#include "dragnn/protos/spec.pb.h"
#include "dragnn/runtime/xla/xla_spec_utils.h"
#include "syntaxnet/base.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/lib/io/path.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/test.h"
namespace
syntaxnet
{
namespace
dragnn
{
namespace
runtime
{
namespace
{
// Arbitrary bogus path.
constexpr
char
kInvalidPath
[]
=
"path/to/some/invalid/file"
;
// Relative path to a MasterSpec.
constexpr
char
kMasterSpecPath
[]
=
"dragnn/runtime/testdata/rnn_tagger/assets.extra/master_spec"
;
// Relative path to a saved model.
constexpr
char
kSavedModelDir
[]
=
"dragnn/runtime/testdata/rnn_tagger"
;
// Relative path to a directory containing expected output.
constexpr
char
kExpectedOutputDir
[]
=
"dragnn/runtime/xla/testdata/xla_compilation_output"
;
// Local relative path to the expected output directory.
constexpr
char
kLocalOutputDir
[]
=
"dragnn/runtime/xla/testdata/xla_compilation_output"
;
// Returns the set of components in the MasterSpec at |kMasterSpecPath|.
std
::
set
<
string
>
GetComponentNames
()
{
return
{
"rnn"
,
"tagger"
};
}
// Returns the absolute path to the test input denoted by the |relative_path|,
// rooted at the shared test-data prefix.
string GetInput(const string &relative_path) {
  const string prefix = test::GetTestDataPrefix();
  return tensorflow::io::JoinPath(prefix, relative_path);
}
// Returns a fresh output directory for each call, by appending a
// monotonically-increasing counter to a name under the test tmp dir.
string GetUniqueOutputDir() {
  static int counter = 0;
  const string subdir = tensorflow::strings::StrCat("output_", counter++);
  return tensorflow::io::JoinPath(tensorflow::testing::TmpDir(), subdir);
}
// Compares the content of the file named |basename| in the |actual_output_dir|
// with the file |testname| in |kExpectedOutputDir|. Can also be modified to
// write the actual file content to |kLocalOutputDir|, for updating test
// expectations (flip the `if (false)` below to `if (true)`).
void CompareOrRewriteTestData(const string &actual_output_dir,
                              const string &basename,
                              const string &testname) {
  string actual_data;
  TF_ASSERT_OK(tensorflow::ReadFileToString(
      tensorflow::Env::Default(),
      tensorflow::io::JoinPath(actual_output_dir, basename), &actual_data));

  if (false) {
    // Rewrite mode: dump the actual content as the new expectation.
    TF_ASSERT_OK(tensorflow::WriteStringToFile(
        tensorflow::Env::Default(),
        tensorflow::io::JoinPath(kLocalOutputDir, testname), actual_data));
  } else {
    string expected_data;
    TF_ASSERT_OK(tensorflow::ReadFileToString(
        tensorflow::Env::Default(),
        GetInput(tensorflow::io::JoinPath(kExpectedOutputDir, testname)),
        &expected_data));

    // Note: EXPECT_EQ is avoided because printing the diff on failure
    // leads to timeouts; EXPECT_TRUE only prints the file names.  (An
    // earlier revision also called EXPECT_EQ here, which defeated that
    // purpose; it has been removed.)
    EXPECT_TRUE(actual_data == expected_data)
        << "Actual and expected file contents differ for " << basename
        << "; (actual in " << actual_output_dir << ")";
  }
}
// Convenience overload: compares (or rewrites) the file named |basename| in
// the |actual_output_dir| against the expectation file with the same
// |basename| in |kExpectedOutputDir|.
void CompareOrRewriteTestData(const string &actual_output_dir,
                              const string &basename) {
  // Delegate to the three-argument form with testname == basename.
  CompareOrRewriteTestData(actual_output_dir, basename, basename);
}
// Reads a text-format MasterSpec from the |master_spec_path|, clears resource
// file patterns, and writes it back to the |master_spec_path|. The resource
// file patterns would otherwise cause spurious mismatches.
void
ClearResourceFilePatterns
(
const
string
&
master_spec_path
)
{
MasterSpec
master_spec
;
TF_ASSERT_OK
(
tensorflow
::
ReadTextProto
(
tensorflow
::
Env
::
Default
(),
master_spec_path
,
&
master_spec
));
for
(
ComponentSpec
&
component_spec
:
*
master_spec
.
mutable_component
())
{
for
(
Resource
&
resource
:
*
component_spec
.
mutable_resource
())
{
for
(
Part
&
part
:
*
resource
.
mutable_part
())
{
part
.
clear_file_pattern
();
}
}
}
TF_ASSERT_OK
(
tensorflow
::
WriteTextProto
(
tensorflow
::
Env
::
Default
(),
master_spec_path
,
master_spec
));
}
// Tests that XlaCompileCells() fails if the saved model is invalid.
TEST(XlaCompileCellsTest, InvalidSavedModel) {
  // Bogus saved-model dir, valid spec: the load of the saved model must fail.
  EXPECT_FALSE(XlaCompileCells(kInvalidPath, GetInput(kMasterSpecPath), {}, "",
                               GetUniqueOutputDir())
                   .ok());
}
// Tests that XlaCompileCells() fails if the master spec is invalid.
TEST(XlaCompileCellsTest, InvalidMasterSpec) {
  // Valid saved model, bogus spec path: reading the MasterSpec must fail.
  EXPECT_FALSE(XlaCompileCells(GetInput(kSavedModelDir), kInvalidPath, {}, "",
                               GetUniqueOutputDir())
                   .ok());
}
// Tests that XlaCompileCells() fails if the MasterSpec contains a duplicate
// component.
TEST(XlaCompileCellsTest, DuplicateComponent) {
  // Two components with the same name trip the duplicate check in
  // GetMatchingComponentSpecs().
  const string kSpec = "component { name:'foo' } component { name:'foo' }";
  const string master_spec_path = tensorflow::io::JoinPath(
      tensorflow::testing::TmpDir(), "master-spec-with-duplicate");
  TF_ASSERT_OK(tensorflow::WriteStringToFile(tensorflow::Env::Default(),
                                             master_spec_path, kSpec));

  EXPECT_THAT(XlaCompileCells(GetInput(kSavedModelDir), master_spec_path, {},
                              "", GetUniqueOutputDir()),
              test::IsErrorWithSubstr("Duplicate component name: foo"));
}
// Tests that XlaCompileCells() fails if one of the requested components does
// not appear in the MasterSpec.
TEST(XlaCompileCellsTest, FilterWithUnknownComponent) {
  // Spec has 'foo' and 'bar'; requesting 'missing' must be rejected.
  const string kSpec = "component { name:'foo' } component { name:'bar' }";
  const string master_spec_path = tensorflow::io::JoinPath(
      tensorflow::testing::TmpDir(), "master-spec-foo-bar");
  TF_ASSERT_OK(tensorflow::WriteStringToFile(tensorflow::Env::Default(),
                                             master_spec_path, kSpec));

  EXPECT_THAT(XlaCompileCells(GetInput(kSavedModelDir), master_spec_path,
                              {"missing"}, "", GetUniqueOutputDir()),
              test::IsErrorWithSubstr("Unknown component name: missing"));
}
// Tests that XlaCompileCells() fails if a component already has a frozen
// GraphDef.
TEST(XlaCompileCellsTest, AlreadyHasFrozenGraphDef) {
  // A pre-existing resource named kFrozenGraphDefResourceName means the spec
  // was already XLA-compiled; re-compiling must be rejected.
  const string kSpec = tensorflow::strings::StrCat(
      "component { name: 'foo' resource { name: '",
      kFrozenGraphDefResourceName, "' } }");
  const string master_spec_path = tensorflow::io::JoinPath(
      tensorflow::testing::TmpDir(), "master-spec-with-flows");
  TF_ASSERT_OK(tensorflow::WriteStringToFile(tensorflow::Env::Default(),
                                             master_spec_path, kSpec));

  EXPECT_THAT(
      XlaCompileCells(GetInput(kSavedModelDir), master_spec_path, {"foo"}, "",
                      GetUniqueOutputDir()),
      test::IsErrorWithSubstr("already contains a frozen TF GraphDef resource"));
}
// Tests that XlaCompileCells() fails on the wrong Component type.
TEST(XlaCompileCellsTest, WrongComponentType) {
  // Only DynamicComponent has an XLA-based counterpart; anything else is
  // Unimplemented.
  const string kSpec =
      "component { name: 'foo' component_builder { registered_name: "
      "'WrongComponent' } }";
  const string master_spec_path = tensorflow::io::JoinPath(
      tensorflow::testing::TmpDir(), "master-spec");
  TF_ASSERT_OK(tensorflow::WriteStringToFile(tensorflow::Env::Default(),
                                             master_spec_path, kSpec));

  EXPECT_THAT(XlaCompileCells(GetInput(kSavedModelDir), master_spec_path,
                              {"foo"}, "", GetUniqueOutputDir()),
              test::IsErrorWithSubstr(
                  "No XLA-based version of Component subclass "
                  "'WrongComponent'"));
}
// Tests that XlaCompileCells() succeeds on the pre-trained inputs and
// reproduces expected outputs.
TEST(XlaCompileCellsTest, RegressionTest) {
  const string output_dir = GetUniqueOutputDir();
  // Empty model name: no AOT CompilationSpec extension is added.
  TF_ASSERT_OK(XlaCompileCells(GetInput(kSavedModelDir),
                               GetInput(kMasterSpecPath), GetComponentNames(),
                               "", output_dir));

  // Strip run-specific file patterns before diffing against goldens.
  ClearResourceFilePatterns(
      tensorflow::io::JoinPath(output_dir, "master-spec"));
  CompareOrRewriteTestData(output_dir, "master-spec");

  // Each component must have produced a frozen GraphDef matching its golden.
  for (const string &component_name : GetComponentNames()) {
    const string graph_def_basename = tensorflow::strings::StrCat(
        component_name, kFrozenGraphDefResourceFileSuffix);
    CompareOrRewriteTestData(output_dir, graph_def_basename);
  }
}
// Tests that XlaCompileCells() succeeds on the pre-trained inputs and
// reproduces expected outputs.
TEST(XlaCompileCellsTest, RegressionTestWithModelNameForAot) {
  const string output_dir = GetUniqueOutputDir();
  // Non-empty model name: the spec gains AOT CompilationSpec extensions, so
  // the golden is the separate "master-spec-aot" file.
  TF_ASSERT_OK(XlaCompileCells(GetInput(kSavedModelDir),
                               GetInput(kMasterSpecPath), GetComponentNames(),
                               "model_v1", output_dir));

  // Strip run-specific file patterns before diffing against goldens.
  ClearResourceFilePatterns(
      tensorflow::io::JoinPath(output_dir, "master-spec"));
  CompareOrRewriteTestData(output_dir, "master-spec", "master-spec-aot");

  // Frozen GraphDefs are identical to the non-AOT case.
  for (const string &component_name : GetComponentNames()) {
    const string graph_def_basename = tensorflow::strings::StrCat(
        component_name, kFrozenGraphDefResourceFileSuffix);
    CompareOrRewriteTestData(output_dir, graph_def_basename);
  }
}
}
// namespace
}
// namespace runtime
}
// namespace dragnn
}
// namespace syntaxnet
research/syntaxnet/dragnn/runtime/xla/xla_dynamic_component.cc
deleted
100644 → 0
View file @
a4bb31d0
// Copyright 2017 Google Inc. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
#include <stddef.h>
#include <memory>
#include <string>
#include "dragnn/core/compute_session.h"
#include "dragnn/protos/export.pb.h"
#include "dragnn/protos/spec.pb.h"
#include "dragnn/protos/trace.pb.h"
#include "dragnn/runtime/component.h"
#include "dragnn/runtime/fixed_embeddings.h"
#include "dragnn/runtime/linked_embeddings.h"
#include "dragnn/runtime/math/types.h"
#include "dragnn/runtime/network_states.h"
#include "dragnn/runtime/session_state.h"
#include "dragnn/runtime/xla/sequence_xla_dynamic_component_mixin.h"
#include "dragnn/runtime/xla/xla_dynamic_component_base.h"
#include "dragnn/runtime/xla/xla_graph_utils.h"
#include "dragnn/runtime/xla/xla_spec_utils.h"
#include "syntaxnet/base.h"
#include "tensorflow/compiler/tf2xla/tf2xla.pb.h"
#include "tensorflow/compiler/tf2xla/xla_compiled_cpu_function.h"
#include "tensorflow/compiler/tf2xla/xla_jit_compiled_cpu_function.h"
#include "tensorflow/core/framework/graph.pb.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/platform/logging.h"
namespace
syntaxnet
{
namespace
dragnn
{
namespace
runtime
{
namespace
{
// An XLA-based version of DynamicComponent using the XLA JIT API.
//
// It uses the XLA JIT API to compile the graph, and uses the frozen GraphDef
// referred to in the component spec.
class XlaDynamicComponent : public XlaDynamicComponentBase {
 protected:
  // Unlike other specializations, this component will only be active if the
  // spec is explicitly modified to support XLA (and frozen graph resources are
  // generated).
  bool Supports(const ComponentSpec &spec,
                const string &normalized_builder_name) const override {
    return normalized_builder_name == "XlaDynamicComponent";
  }

  // Never preferred over another matching Component implementation; selection
  // happens purely via the explicit spec rewrite checked in Supports().
  bool PreferredTo(const Component &other) const override { return false; }

  // Gets the frozen GraphDef using the |component_spec| and compiles it.
  // The |cell_subgraph_spec| contained within it is filled in. On error,
  // returns non-OK.
  tensorflow::Status InitializeFromComponentSpec(
      const ComponentSpec &component_spec,
      CellSubgraphSpec *cell_subgraph_spec) override;

  // Returns the static data of the JIT-compiled cell.  LOG(FATAL)s if called
  // before InitializeFromComponentSpec() has populated |jit_|.
  const tensorflow::XlaCompiledCpuFunction::StaticData &XlaStaticData()
      const override {
    if (jit_ == nullptr) {
      LOG(FATAL) << "XlaStaticData() called before "
                    "InitializeFromComponentSpec() for component "
                 << name();
    }
    return jit_->StaticData();
  }

 private:
  // Cell that contains the compiled code for this component.
  std::unique_ptr<tensorflow::XlaJitCompiledCpuFunction> jit_;
};
// Loads the frozen GraphDef named by the |component_spec|'s resource, builds
// an XLA Config from the embedded CellSubgraphSpec (also returned via
// |cell_subgraph_spec|), and JIT-compiles the cell into |jit_|.
tensorflow::Status XlaDynamicComponent::InitializeFromComponentSpec(
    const ComponentSpec &component_spec,
    CellSubgraphSpec *cell_subgraph_spec) {
  const Resource *resource = nullptr;
  TF_RETURN_IF_ERROR(LookupFrozenGraphDefResource(component_spec, &resource));
  // NOTE(review): assumes the resource has at least one part; presumably
  // LookupFrozenGraphDefResource() guarantees this — confirm.
  const string &frozen_graph_def_path = resource->part(0).file_pattern();
  tensorflow::GraphDef frozen_graph_def;
  TF_RETURN_IF_ERROR(
      LoadFrozenGraphDef(frozen_graph_def_path, &frozen_graph_def));

  // Gets the CellSubgraphSpec from the frozen GraphDef and constructs
  // the XLA Config required for compilation.
  tensorflow::tf2xla::Config xla_config;
  TF_RETURN_IF_ERROR(GetSpecAndMakeXlaConfig(frozen_graph_def,
                                             cell_subgraph_spec,
                                             &xla_config));

  // Compiles the cell.
  TF_ASSIGN_OR_RETURN(jit_, tensorflow::XlaJitCompiledCpuFunction::Compile(
                                frozen_graph_def, xla_config,
                                xla::ExecutableBuildOptions()));
  return tensorflow::Status::OK();
}
// Register the JIT component with the runtime's Component registry.
DRAGNN_RUNTIME_REGISTER_COMPONENT(XlaDynamicComponent);

// Sequence-based version of the above.
using SequenceXlaDynamicComponent =
    SequenceXlaDynamicComponentMixin<XlaDynamicComponent>;
DRAGNN_RUNTIME_REGISTER_COMPONENT(SequenceXlaDynamicComponent);
}
// namespace
}
// namespace runtime
}
// namespace dragnn
}
// namespace syntaxnet
research/syntaxnet/dragnn/runtime/xla/xla_dynamic_component_base.cc
deleted
100644 → 0
View file @
a4bb31d0
// Copyright 2017 Google Inc. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
#include "dragnn/runtime/xla/xla_dynamic_component_base.h"
#include <string.h>
#include <algorithm>
#include "dragnn/protos/export.pb.h"
#include "dragnn/runtime/transition_system_traits.h"
#include "dragnn/runtime/xla/xla_spec_utils.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/lib/core/error_codes.pb.h"
#include "tensorflow/core/lib/core/errors.h"
namespace
syntaxnet
{
namespace
dragnn
{
namespace
runtime
{
constexpr
char
XlaDynamicComponentBase
::
kLogitsName
[];
// Checks that the |component_spec| is compatible with XLA execution: no
// attention, and every fixed/linked channel marked non-embedded
// (embedding_dim == -1, as written by XlaCompileCells()).  Returns non-OK on
// the first violation found.
tensorflow::Status XlaDynamicComponentBase::Validate(
    const ComponentSpec &component_spec) {
  if (!component_spec.attention_component().empty()) {
    return tensorflow::errors::Unimplemented("Attention is not supported");
  }

  for (const auto &fixed_channel : component_spec.fixed_feature()) {
    if (fixed_channel.embedding_dim() == -1) continue;
    return tensorflow::errors::InvalidArgument(
        "XLA requires non-embedded fixed features");
  }

  for (const auto &linked_channel : component_spec.linked_feature()) {
    if (linked_channel.embedding_dim() == -1) continue;
    return tensorflow::errors::InvalidArgument(
        "XLA requires non-multiplied linked features");
  }

  return tensorflow::Status::OK();
}
// Checks that an XLA tensor |shape| is a vector-like tensor of |type| with
// |dimension| elements, and writes the element count to |elements_out|.
// "Vector-like" means at most one dimension larger than 1; extra size-1 dims
// are ignored.  A |dimension| < 0 disables the element-count check.
// On error, returns non-OK and leaves |elements_out| unmodified.
tensorflow::Status XlaDynamicComponentBase::ValidateTensor(
    const string &name, const xla::PrimitiveType type, int dimension,
    const xla::Shape &shape, int *elements_out) {
  if (shape.element_type() != type) {
    return tensorflow::errors::InvalidArgument(
        "XLA tensor '", name, "' has wrong type ",
        xla::PrimitiveType_Name(shape.element_type()), " (expected ",
        xla::PrimitiveType_Name(type), ")");
  }

  // Count dims > 1 and accumulate the element product over them.
  // NOTE(review): a dimension of 0 is skipped by the `dim > 1` filter, so a
  // zero-element shape reports 1 element — presumably such shapes never occur
  // here; confirm if that assumption changes.
  int num_nontrivial_dims = 0;
  int64 elements = 1;
  for (int64 dim : shape.dimensions()) {
    if (dim > 1) {
      ++num_nontrivial_dims;
      elements *= dim;
    }
  }

  if (num_nontrivial_dims > 1) {
    return tensorflow::errors::InvalidArgument(
        "XLA tensor has non-vector-like shape: '", name, "' ",
        xla::ShapeUtil::HumanString(shape));
  }

  if (dimension >= 0 && elements != dimension) {
    return tensorflow::errors::InvalidArgument(
        "XLA input shape has the wrong dimension '", name, "' ",
        xla::ShapeUtil::HumanString(shape), " (expected ", dimension, ")");
  }

  *elements_out = static_cast<int>(elements);
  return tensorflow::Status::OK();
}
// Looks up the compiled-cell input argument named |name| in |instance|,
// validates that it is a vector-like tensor of |type| with |dimension|
// elements, and fills in |input_handle| (arg index and element count).
// Returns NotFound if no such argument exists; on any error,
// |input_handle->index| is left at -1 (invalid).
tensorflow::Status XlaDynamicComponentBase::LookupInputVector(
    const string &name, const xla::PrimitiveType type, int dimension,
    const tensorflow::XlaCompiledCpuFunction &instance,
    InputHandle *input_handle) const {
  input_handle->index = -1;  // set to invalid if we error out
  const int index = instance.LookupArgIndex(name);
  if (index == -1 || index >= program_shape_->parameters_size()) {
    return tensorflow::errors::NotFound("No XLA tensor named '", name, "'");
  }

  const xla::Shape &shape = program_shape_->parameters(index);
  TF_RETURN_IF_ERROR(
      ValidateTensor(name, type, dimension, shape, &input_handle->elements));

  // Only mark the handle valid once all checks have passed.
  input_handle->index = index;
  return tensorflow::Status::OK();
}
// Looks up the compiled-cell result named |name| in |instance|, validates
// that the overall result is a tuple whose |index|-th element is a
// vector-like tensor of |type| with |dimension| elements, and fills in
// |output_handle| (tuple index, element count, and byte size).  Returns
// NotFound if no such result exists; on any error, |output_handle->index| is
// left at -1 (invalid).
tensorflow::Status XlaDynamicComponentBase::LookupOutputVector(
    const string &name, const xla::PrimitiveType type, int dimension,
    const tensorflow::XlaCompiledCpuFunction &instance,
    OutputHandle *output_handle) const {
  output_handle->index = -1;  // set to invalid if we error out
  const int index = instance.LookupResultIndex(name);
  if (index == -1) {
    return tensorflow::errors::NotFound("No XLA tensor named '", name, "'");
  }

  // The compiled cell returns all outputs packed into a single tuple.
  const xla::Shape &result_shape = program_shape_->result();
  if (result_shape.element_type() != xla::TUPLE) {
    return tensorflow::errors::InvalidArgument("XLA output is not a tuple");
  }
  if (index >= result_shape.tuple_shapes_size()) {
    return tensorflow::errors::InvalidArgument("Invalid XLA output index: ",
                                               index);
  }

  const xla::Shape &shape = result_shape.tuple_shapes(index);
  TF_RETURN_IF_ERROR(
      ValidateTensor(name, type, dimension, shape, &output_handle->elements));

  // Only mark the handle valid once all checks have passed.
  output_handle->index = index;
  output_handle->bytes = xla::ShapeUtil::ByteSizeOf(shape);
  return tensorflow::Status::OK();
}
// Binds one compiled-cell S32 input of size 1 to each fixed feature ID slot,
// filling |input_ids_| (indexed by channel_base + index, matching
// |fixed_embedding_manager_|'s layout).  On error, returns non-OK.
tensorflow::Status XlaDynamicComponentBase::InitializeInputIds(
    const tensorflow::XlaCompiledCpuFunction &instance) {
  const int num_channels = fixed_embedding_manager_.num_channels();
  input_ids_.resize(fixed_embedding_manager_.num_embeddings());
  for (int channel_id = 0; channel_id < num_channels; ++channel_id) {
    // Validate() enforced embedding_dim == -1, so no channel is embedded.
    DCHECK(!fixed_embedding_manager_.is_embedded(channel_id));
    const int channel_base = fixed_embedding_manager_.channel_base(channel_id);
    const int channel_size = fixed_embedding_manager_.channel_size(channel_id);
    for (int index = 0; index < channel_size; ++index) {
      InputId &input = input_ids_[channel_base + index];
      const string name = MakeXlaInputFixedFeatureIdName(channel_id, index);
      // Each feature ID is a single S32 scalar input of the cell.
      TF_RETURN_IF_ERROR(
          LookupInputVector(name, xla::S32, 1, instance, &input.id));
      VLOG(1) << "Component '" << name_ << "' fixed channel " << channel_id
              << " index " << index << ": Added feature ID";
    }
  }
  return tensorflow::Status::OK();
}
// Binds the compiled-cell inputs for each linked feature channel: the F32
// activation vector (required) and, where present, the F32 out-of-bounds
// indicator used by channels that multiply activations with an embedding
// matrix.  Fills |input_links_|.  On error, returns non-OK.
tensorflow::Status XlaDynamicComponentBase::InitializeInputLinks(
    const tensorflow::XlaCompiledCpuFunction &instance) {
  const int num_channels = linked_embedding_manager_.num_channels();
  input_links_.resize(num_channels);
  for (int channel_id = 0; channel_id < num_channels; ++channel_id) {
    InputLink &input = input_links_[channel_id];
    const int dimension = linked_embedding_manager_.embedding_dim(channel_id);
    const string activations_name =
        MakeXlaInputLinkedActivationVectorName(channel_id);
    const string out_of_bounds_name =
        MakeXlaInputLinkedOutOfBoundsIndicatorName(channel_id);

    // The activation vector input is mandatory for every linked channel.
    TF_RETURN_IF_ERROR(LookupInputVector(activations_name, xla::F32, dimension,
                                         instance, &input.activations));
    VLOG(1) << "Component '" << name_ << "' linked channel " << channel_id
            << ": Added activations";

    // Allow NOT_FOUND, for linked embedding channels that don't multiply the
    // input activations with an embedding matrix.
    const tensorflow::Status status = LookupInputVector(
        out_of_bounds_name, xla::F32, 1, instance, &input.out_of_bounds);
    if (status.ok()) {
      VLOG(1) << "Component '" << name_ << "' linked channel " << channel_id
              << ": Added out-of-bounds indicator for multiplication";
    } else if (status.code() == tensorflow::error::NOT_FOUND) {
      VLOG(1) << "Component '" << name_ << "' linked channel " << channel_id
              << ": No out-of-bounds indicator; not multiplied";
    } else {
      // Any error other than NOT_FOUND is a real failure.
      return status;
    }
  }
  return tensorflow::Status::OK();
}
// For every TYPE_RECURRENT input in the |cell_subgraph_spec|, records the
// layer handle (looked up in the |manager|) and the matching F32 cell input
// so the previous step's output can be fed back at inference time.
tensorflow::Status XlaDynamicComponentBase::InitializeInputRecurrences(
    const CellSubgraphSpec &cell_subgraph_spec,
    const NetworkStateManager &manager,
    const tensorflow::XlaCompiledCpuFunction &instance) {
  for (const auto &cell_input : cell_subgraph_spec.input()) {
    // Only recurrent inputs are handled here; features and links are
    // configured elsewhere.
    if (cell_input.type() != CellSubgraphSpec::Input::TYPE_RECURRENT) {
      continue;
    }
    const string &layer_name = cell_input.name();
    input_recurrences_.emplace_back();
    InputRecurrence &recurrence = input_recurrences_.back();
    const string tensor_name = MakeXlaInputRecurrentLayerName(layer_name);

    // LookupLayer() overwrites |dimension| with the layer's actual size.
    size_t dimension = 1;
    TF_RETURN_IF_ERROR(
        manager.LookupLayer(name_, layer_name, &dimension, &recurrence.handle));
    TF_RETURN_IF_ERROR(LookupInputVector(tensor_name, xla::F32, dimension,
                                         instance,
                                         &recurrence.previous_output));
    VLOG(1) << "Component '" << name_ << "' recurrence '" << layer_name
            << "': Added link to previous output";
  }
  return tensorflow::Status::OK();
}
// Registers one network-state layer per distinct output tensor of the cell.
// When several outputs share the same tensor, the first becomes a real layer
// and the rest become aliases of it.
tensorflow::Status XlaDynamicComponentBase::InitializeOutputLayers(
    const CellSubgraphSpec &cell_subgraph_spec, NetworkStateManager *manager,
    const tensorflow::XlaCompiledCpuFunction &instance) {
  // Mapping from output tensor name to layer name, for detecting layer aliases.
  std::map<string, string> tensor_to_layer;
  for (const auto &cell_output : cell_subgraph_spec.output()) {
    const string &layer_name = cell_output.name();
    output_layers_.emplace_back();
    OutputLayer &output = output_layers_.back();
    const string tensor_name = MakeXlaOutputLayerName(layer_name);

    // Add a new output layer or create an alias to an existing one.
    const auto existing = tensor_to_layer.find(cell_output.tensor());
    if (existing == tensor_to_layer.end()) {
      // First sighting of this tensor: make a real layer.  Dimension -1 means
      // "accept any size"; the actual size comes back in |output.layer|.
      TF_RETURN_IF_ERROR(
          LookupOutputVector(tensor_name, xla::F32, -1, instance, &output.layer));
      tensor_to_layer[cell_output.tensor()] = layer_name;
      TF_RETURN_IF_ERROR(manager->AddLayer(layer_name, output.layer.elements,
                                           &output.handle));
      VLOG(1) << "Component '" << name_ << "' output '" << layer_name
              << "': Added new layer";
    } else {
      // The tensor already backs another layer, so this is just an alias and
      // not a "real" output; discard the provisional OutputLayer.
      const string &original_name = existing->second;
      output_layers_.pop_back();
      TF_RETURN_IF_ERROR(manager->AddLayerAlias(layer_name, original_name));
      VLOG(1) << "Component '" << name_ << "' output '" << layer_name
              << "': Alias of '" << original_name << "'";
    }
  }
  return tensorflow::Status::OK();
}
// Allocates and initializes the constant vectors |zero_|, |one_|, and
// |zeros_| in a single aligned backing array.  |zeros_| is sized to the
// largest recurrent input so it can substitute for any previous output.
tensorflow::Status XlaDynamicComponentBase::InitializeConstantVectors() {
  // Find the maximum recurrent layer dimension; the |zeros_| must be this big.
  // Start at 1 so there is always at least one element, for |zero_|.
  int max_dimension = 1;
  for (const InputRecurrence &recurrence : input_recurrences_) {
    max_dimension =
        std::max(max_dimension, recurrence.previous_output.elements);
  }

  // One backing allocation, split into [ |one_| | |zeros_| ].  |zero_| is a
  // one-element view of the start of |zeros_|.
  const std::vector<size_t> sizes = {sizeof(float),
                                     max_dimension * sizeof(float)};
  array_.Reset(ComputeTotalBytesWithAlignmentPadding(sizes));
  // All-bits-zero is 0.0f, so a memset zero-initializes every float.
  memset(array_.view().data(), 0, array_.view().size());
  std::vector<MutableAlignedView> views;
  TF_RETURN_IF_ERROR(array_.view().Split(sizes, &views));
  DCHECK_EQ(views.size(), 2);

  // Promote the raw views to typed vectors.
  one_ = Vector<float>(views[0]);
  zero_ = Vector<float>(views[1], 1);
  zeros_ = Vector<float>(views[1]);
  DCHECK_EQ(zero_.size(), 1);
  DCHECK_EQ(one_.size(), 1);
  DCHECK_EQ(zeros_.size(), max_dimension);

  // All memory was already zeroed, so only |one_| needs to be initialized.
  MutableVector<float> mutable_one(views[0]);
  mutable_one[0] = 1.0;
  return tensorflow::Status::OK();
}
// Looks up the classification logits layer, unless the transition system is
// deterministic (in which case logits are never consulted).  Fails if the
// logits dimension disagrees with ComponentSpec.num_actions.
tensorflow::Status XlaDynamicComponentBase::MaybeInitializeLogits(
    const ComponentSpec &component_spec, const NetworkStateManager &manager) {
  // Logits are unnecessary when the component is deterministic.
  deterministic_ = TransitionSystemTraits(component_spec).is_deterministic;
  if (deterministic_) return tensorflow::Status::OK();

  size_t dimension = 0;
  TF_RETURN_IF_ERROR(
      manager.LookupLayer(name_, kLogitsName, &dimension, &logits_handle_));
  if (dimension != component_spec.num_actions()) {
    return tensorflow::errors::InvalidArgument(
        "Dimension mismatch between classification logits (", dimension,
        ") and ComponentSpec.num_actions (", component_spec.num_actions(),
        ") in component '", name_, "'");
  }
  return tensorflow::Status::OK();
}
// Initializes this component from the |component_spec|: compiles/loads the
// XLA cell, wires its inputs and outputs to runtime-managed operands, and
// registers the session-state extensions it needs.
tensorflow::Status XlaDynamicComponentBase::Initialize(
    const ComponentSpec &component_spec, VariableStore *variable_store,
    NetworkStateManager *network_state_manager,
    ExtensionManager *extension_manager) {
  name_ = component_spec.name();
  TF_RETURN_IF_ERROR(Validate(component_spec));
  CellSubgraphSpec cell_subgraph_spec;
  TF_RETURN_IF_ERROR(
      InitializeFromComponentSpec(component_spec, &cell_subgraph_spec));

  // Cache the XLA StaticData after InitializeFromComponentSpec().
  static_data_ = &XlaStaticData();

  // Make a temporary instance to determine shape and input/output indices.
  tensorflow::XlaCompiledCpuFunction instance(
      *static_data_, tensorflow::XlaCompiledCpuFunction::AllocMode::
                         RESULTS_PROFILES_AND_TEMPS_ONLY);
  program_shape_ = instance.ProgramShape();
  if (program_shape_ == nullptr) {
    // Note: this fails when the proto dependency is missing.
    return tensorflow::errors::InvalidArgument("XLA program shape missing");
  }
  VLOG(1) << "XLA program shape = " << program_shape_->DebugString();

  // Configure the inputs and outputs of the XLA cell.  As with NetworkUnit
  // and NetworkUnitBase, output layers and input features must be initialized
  // in a particular order to enable recurrent inputs.  Specifically, output
  // layers must be populated first, so they are available for recurrent
  // access, both by the |input_recurrences_| and the
  // |linked_embedding_manager_|.
  TF_RETURN_IF_ERROR(InitializeOutputLayers(cell_subgraph_spec,
                                            network_state_manager, instance));
  TF_RETURN_IF_ERROR(fixed_embedding_manager_.Reset(
      component_spec, variable_store, network_state_manager));
  TF_RETURN_IF_ERROR(linked_embedding_manager_.Reset(
      component_spec, variable_store, network_state_manager));
  TF_RETURN_IF_ERROR(InitializeInputIds(instance));
  TF_RETURN_IF_ERROR(InitializeInputLinks(instance));
  TF_RETURN_IF_ERROR(InitializeInputRecurrences(
      cell_subgraph_spec, *network_state_manager, instance));
  TF_RETURN_IF_ERROR(InitializeConstantVectors());
  TF_RETURN_IF_ERROR(
      MaybeInitializeLogits(component_spec, *network_state_manager));

  // Embeddings are shared across components; the XLA instance is local
  // because each component can have a different cell.
  extension_manager->GetShared(&fixed_embeddings_handle_);
  extension_manager->GetShared(&linked_embeddings_handle_);
  extension_manager->AddLocal(&instance_handle_);
  return tensorflow::Status::OK();
}
// Runs the compiled XLA cell once per transition until the transition system
// is terminal, binding features/links/recurrences before each run, copying
// outputs afterwards, and advancing the transition system from the oracle
// (deterministic components) or from the predicted logits.
tensorflow::Status XlaDynamicComponentBase::Evaluate(
    SessionState *session_state, ComputeSession *compute_session,
    ComponentTrace *component_trace) const {
  NetworkStates &network_states = session_state->network_states;
  FixedEmbeddings &fixed_embeddings =
      session_state->extensions.Get(fixed_embeddings_handle_);
  LinkedEmbeddings &linked_embeddings =
      session_state->extensions.Get(linked_embeddings_handle_);
  tensorflow::XlaCompiledCpuFunction &instance = GetInstance(session_state);

  for (size_t step_index = 0; !compute_session->IsTerminal(name());
       ++step_index) {
    network_states.AddStep();
    TF_RETURN_IF_ERROR(fixed_embeddings.Reset(&fixed_embedding_manager(),
                                              network_states,
                                              compute_session));
    TF_RETURN_IF_ERROR(linked_embeddings.Reset(&linked_embedding_manager(),
                                               network_states,
                                               compute_session));

    // Bind inputs into the |instance|.
    BindInputIds(fixed_embeddings, &instance);
    BindInputLinks(linked_embeddings, &instance);
    BindInputRecurrences(step_index, network_states, &instance);

    // Invoke the cell in the |instance|.
    if (!instance.Run()) {
      return tensorflow::errors::Internal("Error executing cell for ", name(),
                                          ": ", instance.error_msg());
    }

    // Realizes the binding: copy outputs out of the |instance|.
    BindOutputLayers(step_index, network_states, &instance);
    MaybeTrace(step_index, &instance, component_trace);

    // If the component is deterministic, take the oracle transition instead
    // of predicting the next transition using the logits.
    if (deterministic()) {
      compute_session->AdvanceFromOracle(name());
    } else {
      // AddStep() may invalidate the logits (due to reallocation), so the
      // layer lookup cannot be hoisted out of this loop.
      const Vector<float> logits(
          network_states.GetLayer(logits_handle()).row(step_index));
      if (!compute_session->AdvanceFromPrediction(
              name(), logits.data(), kEvaluateNumItems, logits.size())) {
        return tensorflow::errors::Internal(
            "Error in ComputeSession::AdvanceFromPrediction()");
      }
    }
  }
  return tensorflow::Status::OK();
}
}
// namespace runtime
}
// namespace dragnn
}
// namespace syntaxnet
research/syntaxnet/dragnn/runtime/xla/xla_dynamic_component_base.h
deleted
100644 → 0
View file @
a4bb31d0
// Copyright 2017 Google Inc. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
#ifndef DRAGNN_RUNTIME_XLA_XLA_DYNAMIC_COMPONENT_BASE_H_
#define DRAGNN_RUNTIME_XLA_XLA_DYNAMIC_COMPONENT_BASE_H_
#include <stddef.h>
#include <string.h>
#include <map>
#include <memory>
#include <string>
#include <vector>
#include "dragnn/protos/export.pb.h"
#include "dragnn/protos/spec.pb.h"
#include "dragnn/protos/trace.pb.h"
#include "dragnn/runtime/alignment.h"
#include "dragnn/runtime/component.h"
#include "dragnn/runtime/extensions.h"
#include "dragnn/runtime/fixed_embeddings.h"
#include "dragnn/runtime/linked_embeddings.h"
#include "dragnn/runtime/math/types.h"
#include "dragnn/runtime/network_states.h"
#include "dragnn/runtime/session_state.h"
#include "dragnn/runtime/type_keyed_set.h"
#include "dragnn/runtime/variable_store.h"
#include "syntaxnet/base.h"
#include "tensorflow/compiler/tf2xla/xla_compiled_cpu_function.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/framework/allocator.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/types.h"
namespace
syntaxnet
{
namespace
dragnn
{
namespace
runtime
{
// Base class for XLA-based versions of DynamicComponent.
//
// Roughly, this is a base class for a version of DynamicComponent where the
// per-transition-step computation is performed by a XLA cell instead of a
// NetworkUnit. This class implements Initialize() and Evaluate(). It has
// the most generality w.r.t. input features and links, but suffers from
// ComputeSession overhead. Subclasses which provide specialized logic that
// replaces the generic ComputeSession should override Evaluate().
//
// XLA JIT and AOT versions of this class must supply appropriate versions
// of InitializeFromComponentSpec() and XlaStaticData().
//
// At initialization time, this class creates lists of configuration structs
// that associate each input or output of the XLA cell with an operand that
// the DRAGNN runtime manages. See, e.g., InputId and InitializeInputIds().
//
// At inference time, subclasses can bind the relevant DRAGNN runtime operands
// to the inputs and outputs of the XLA instance (see, e.g., BindInputIds())
// and evaluate the XLA cell. Like DynamicComponent, the cell should be
// evaluated once per transition and the results used to advance the transition
// system state.
//
// Except as noted below, this is a drop-in replacement for DynamicComponent:
// * The name of the logits layer is hard-coded (see kLogitsName).
// * The fixed and linked channels must have embedding_dim=-1, because the fixed
// lookups and linked multiplications are handled within XLA.
//
// The XlaDynamicComponent subclass provides a general-purpose implementation
// of Evaluate(). Other subclasses provide optimized implementations subject to
// restrictions on the possible network configuration.
class XlaDynamicComponentBase : public Component {
 public:
  // Implements Component.
  tensorflow::Status Initialize(const ComponentSpec &component_spec,
                                VariableStore *variable_store,
                                NetworkStateManager *network_state_manager,
                                ExtensionManager *extension_manager) override;
  tensorflow::Status Evaluate(SessionState *session_state,
                              ComputeSession *compute_session,
                              ComponentTrace *component_trace) const override;

 protected:
  // Initializes the XLA function using the |component_spec|.  When
  // successful, the relevant |cell_subgraph_spec| is filled in, and
  // XlaStaticData() is safe to call.  On error, returns non-OK.
  virtual tensorflow::Status InitializeFromComponentSpec(
      const ComponentSpec &component_spec,
      CellSubgraphSpec *cell_subgraph_spec) = 0;

  // Returns the StaticData that identifies a specific XLA compiled cell
  // function.  It is a fatal error to call this before a successful call to
  // InitializeFromSpec().
  virtual const tensorflow::XlaCompiledCpuFunction::StaticData &XlaStaticData()
      const = 0;

 private:
  // Handle to one of the inputs.  The |index| is into an array of pointers
  // used by XlaCompiledCpuFunction.  The input vector has the given number of
  // |elements|.
  struct InputHandle {
    int index = -1;
    int elements = 0;
  };

  // Handle to one of the outputs.  This |index| is into an array of pointers
  // into the results tuple used by XlaCompiledCpuFunction.  |bytes| is the
  // byte size of the output, used when copying it out.
  struct OutputHandle {
    int index = -1;
    int elements = 0;
    int64 bytes = 0;
  };

 protected:
  // Configuration for a fixed feature ID input.
  struct InputId {
    // Tensor to feed with the fixed feature ID.
    InputHandle id;
  };

  // Configuration for a linked feature embedding input.
  struct InputLink {
    // Tensor to feed with the linked activation vector.
    InputHandle activations;

    // Tensor to feed with the linked out-of-bounds indicator, or -1 if the
    // embedding does not need to be multiplied.
    InputHandle out_of_bounds;
  };

  // Configuration for a recurrent input.
  struct InputRecurrence {
    // Handle of the output layer that is recurrently fed back.
    LayerHandle<float> handle;

    // Tensor to feed with the previous output activation vector.
    InputHandle previous_output;
  };

  // Configuration for an output layer.
  struct OutputLayer {
    // Handle of the output layer.
    LayerHandle<float> handle;

    // Tensor that writes to the layer.
    OutputHandle layer;
  };

  // Name of the layer containing logits.  Unlike DynamicComponent, this class
  // does not use the NetworkUnit abstraction and assumes that the logits will
  // be stored in this layer.
  // TODO(googleuser): Make this configurable, if needed.  The logits layer
  // could be given a special alias, for example.
  static constexpr char kLogitsName[] = "logits";

  // Points the cell input |handle| in the |instance| at the |vector|.
  // Must be called before invoking the cell.
  template <class T>
  static void BindInput(Vector<T> vector, const InputHandle &handle,
                        tensorflow::XlaCompiledCpuFunction *instance);

  // Copies the cell output |handle| in the |instance| to the |vector|.
  // Must be called after invoking the cell.
  //
  // TODO(googleuser): Consider wrapping XlaCompiledCpuFunction along with a
  // map from output indices to layer pointers, so this actually binds before
  // the call to Run().  Then add a separate function that realizes the output
  // binding, copying after Run().
  template <class T>
  static void BindOutput(MutableVector<T> vector, const OutputHandle &handle,
                         tensorflow::XlaCompiledCpuFunction *instance);

  // Binds the feature IDs in the |fixed_embeddings| to the |instance| as
  // configured by the |input_ids_|.
  void BindInputIds(const FixedEmbeddings &fixed_embeddings,
                    tensorflow::XlaCompiledCpuFunction *instance) const;

  // Binds the |embedding| and, if applicable, |is_out_of_bounds| to the
  // |input_link| in the |instance|.
  void BindInputLink(Vector<float> embedding, bool is_out_of_bounds,
                     const InputLink &input_link,
                     tensorflow::XlaCompiledCpuFunction *instance) const;

  // Binds the activation vectors in the |linked_embeddings| to the |instance|
  // as configured by the |input_links_|.
  void BindInputLinks(const LinkedEmbeddings &linked_embeddings,
                      tensorflow::XlaCompiledCpuFunction *instance) const;

  // Binds the output of the step before |step_index| in the |network_states|
  // to the |instance| as configured by the |input_recurrences_|.
  void BindInputRecurrences(size_t step_index,
                            const NetworkStates &network_states,
                            tensorflow::XlaCompiledCpuFunction *instance) const;

  // Binds the output layers for the |step_index| in the |network_states| to
  // the |instance| as configured by the |output_layers_|.
  void BindOutputLayers(size_t step_index, const NetworkStates &network_states,
                        tensorflow::XlaCompiledCpuFunction *instance) const;

  // Returns the reusable XLA instance in the |session_state|.
  tensorflow::XlaCompiledCpuFunction &GetInstance(
      SessionState *session_state) const;

  // If |component_trace| is non-null, ensures that |step_index|+1 steps exist
  // and traces the |instance| in the |step_index|'th step.
  void MaybeTrace(size_t step_index,
                  tensorflow::XlaCompiledCpuFunction *instance,
                  ComponentTrace *component_trace) const;

  // Accessors.
  const string &name() const { return name_; }
  const FixedEmbeddingManager &fixed_embedding_manager() const {
    return fixed_embedding_manager_;
  }
  const LinkedEmbeddingManager &linked_embedding_manager() const {
    return linked_embedding_manager_;
  }
  const std::vector<InputId> &input_ids() const { return input_ids_; }
  const std::vector<InputLink> &input_links() const { return input_links_; }
  const std::vector<InputRecurrence> &input_recurrences() const {
    return input_recurrences_;
  }
  const std::vector<OutputLayer> &output_layers() const {
    return output_layers_;
  }
  bool deterministic() const { return deterministic_; }
  LayerHandle<float> logits_handle() const { return logits_handle_; }

 private:
  // Forbid batches and beams.
  static constexpr int kEvaluateNumItems = 1;

  // Required alignment of pointers to input tensors.
  static constexpr size_t kXlaByteAlignment =
      tensorflow::Allocator::kAllocatorAlignment;

  // Returns non-OK if the |component_spec| specifies any unsupported
  // settings.  This includes both settings that are not yet implemented and
  // those that are fundamentally incompatible with this class.
  static tensorflow::Status Validate(const ComponentSpec &component_spec);

  // Returns non-OK if the tensor called |name| isn't compatible with |type|
  // or has an invalid |shape| given |dimension| for use as an input or
  // output.  If OK, |elements_out| contains the number of elements in the
  // vector.
  static tensorflow::Status ValidateTensor(const string &name,
                                           const xla::PrimitiveType type,
                                           int dimension,
                                           const xla::Shape &shape,
                                           int *elements_out);

  // Points the |input_handle| or |output_handle| at the variable in the
  // |network_| named |name|, which must have a vector-like shape (i.e.,
  // having at most one dimension > 1) and must match the |type|.  The
  // |instance| is used to determine the mapping from |name| to the handle.
  // If the |dimension| is >= 0, then the |vector| must be the same size.
  // On error, returns non-OK and sets |vector| to nullptr.
  // Returns NOT_FOUND iff the |name| does not name a variable.
  tensorflow::Status LookupInputVector(
      const string &name, const xla::PrimitiveType type, int dimension,
      const tensorflow::XlaCompiledCpuFunction &instance,
      InputHandle *input_handle) const;
  tensorflow::Status LookupOutputVector(
      const string &name, const xla::PrimitiveType type, int dimension,
      const tensorflow::XlaCompiledCpuFunction &instance,
      OutputHandle *output_handle) const;

  // Initializes the |input_ids_| based on the |fixed_embedding_manager_| and
  // |network_|.  On error, returns non-OK.
  tensorflow::Status InitializeInputIds(
      const tensorflow::XlaCompiledCpuFunction &instance);

  // Initializes the |input_links_| based on the |linked_embedding_manager_|
  // and |network_|.  On error, returns non-OK.
  tensorflow::Status InitializeInputLinks(
      const tensorflow::XlaCompiledCpuFunction &instance);

  // Initializes the |input_recurrences_| based on the |config|, |manager|,
  // and |network_|.  Requires that layers have been added to the |manager|.
  // On error, returns non-OK.
  tensorflow::Status InitializeInputRecurrences(
      const CellSubgraphSpec &cell_subgraph_spec,
      const NetworkStateManager &manager,
      const tensorflow::XlaCompiledCpuFunction &instance);

  // Initializes the |output_layers_| based on the |config|, |manager|, and
  // |network_|.  Adds layers to the |manager|.  On error, returns non-OK.
  tensorflow::Status InitializeOutputLayers(
      const CellSubgraphSpec &cell_subgraph_spec, NetworkStateManager *manager,
      const tensorflow::XlaCompiledCpuFunction &instance);

  // Initializes the constant vectors (|zero_|, |one_|, and |zeros_|) and
  // their backing |array_|.  Requires that the |input_recurrences_| are
  // initialized.
  tensorflow::Status InitializeConstantVectors();

  // Initializes the |logits_handle_| based on the |component_spec| and
  // |manager|, if needed.
  tensorflow::Status MaybeInitializeLogits(const ComponentSpec &component_spec,
                                           const NetworkStateManager &manager);

  // Name of this component.
  string name_;

  // Managers for the fixed and linked embeddings used by the component.
  FixedEmbeddingManager fixed_embedding_manager_;
  LinkedEmbeddingManager linked_embedding_manager_;

  // Fixed and linked embeddings.
  SharedExtensionHandle<FixedEmbeddings> fixed_embeddings_handle_;
  SharedExtensionHandle<LinkedEmbeddings> linked_embeddings_handle_;

  // The StaticData that identifies the XLA compiled function that implements
  // the network cell.  Cached to reduce virtual call overhead.
  const tensorflow::XlaCompiledCpuFunction::StaticData *static_data_ = nullptr;

  // Description of shapes and types of the compiled function, with indices
  // that correspond to InputHandle and OutputHandle index values.
  const xla::ProgramShape *program_shape_ = nullptr;

  // List of fixed feature ID inputs, aligned with the relevant
  // FixedEmbeddings.
  std::vector<InputId> input_ids_;

  // List of linked feature inputs, aligned with the relevant
  // LinkedEmbeddings.
  std::vector<InputLink> input_links_;

  // List of recurrent input, not ordered.
  std::vector<InputRecurrence> input_recurrences_;

  // List of output layers, not ordered.
  std::vector<OutputLayer> output_layers_;

  // A few constant vectors and their backing array.
  UniqueAlignedArray array_;
  Vector<float> zero_;   // [0.0], for linked out-of-bounds indicators
  Vector<float> one_;    // [1.0], for linked out-of-bounds indicators
  Vector<float> zeros_;  // [0.0...0.0], for linked activation vectors

  // Whether the transition system is deterministic.
  bool deterministic_ = false;

  // Handle to the classification logits.  Valid iff |deterministic_| is
  // false.
  LayerHandle<float> logits_handle_;

  // Compiled function that implements the network cell.  Local, since each
  // component can have a different cell.
  LocalExtensionHandle<tensorflow::XlaCompiledCpuFunction> instance_handle_;
};
// Implementation details below.
// Points the XLA argument slot |handle.index| at the data of |vector|.
template <class T>
void XlaDynamicComponentBase::BindInput(
    Vector<T> vector, const InputHandle &handle,
    tensorflow::XlaCompiledCpuFunction *instance) {
  DCHECK_GE(handle.index, 0);
  DCHECK_EQ(reinterpret_cast<size_t>(vector.data()) % kXlaByteAlignment, 0);

  // Since XLA only consumes non-const pointers, const_cast() is required.
  // XLA will not modify the contents of the |vector|, provided it is bound
  // to a cell input.
  const void *data = reinterpret_cast<const void *>(vector.data());
  instance->set_arg_data(handle.index, const_cast<void *>(data));
}
// Copies the XLA result slot |handle.index| into |vector|.
template <class T>
void XlaDynamicComponentBase::BindOutput(
    MutableVector<T> vector, const OutputHandle &handle,
    tensorflow::XlaCompiledCpuFunction *instance) {
  DCHECK_GE(handle.index, 0);

  // XLA retains control over the allocation of outputs, and the pointer
  // to the output must be determined using result_data() after every call
  // to Run().  The outputs are copied into the session tensors.
  std::memcpy(vector.data(), instance->result_data(handle.index),
              handle.bytes);
}
inline
void
XlaDynamicComponentBase
::
BindInputIds
(
const
FixedEmbeddings
&
fixed_embeddings
,
tensorflow
::
XlaCompiledCpuFunction
*
instance
)
const
{
for
(
size_t
i
=
0
;
i
<
input_ids_
.
size
();
++
i
)
{
BindInput
(
fixed_embeddings
.
ids
(
i
),
input_ids_
[
i
].
id
,
instance
);
}
}
// Binds one linked channel: the activation vector always, and the constant
// 0/1 out-of-bounds indicator only when the channel is multiplied.
inline void XlaDynamicComponentBase::BindInputLink(
    Vector<float> embedding, bool is_out_of_bounds,
    const InputLink &input_link,
    tensorflow::XlaCompiledCpuFunction *instance) const {
  BindInput(embedding, input_link.activations, instance);

  // index == -1 means this channel has no out-of-bounds indicator input.
  if (input_link.out_of_bounds.index != -1) {
    const Vector<float> &indicator = is_out_of_bounds ? one_ : zero_;
    BindInput(indicator, input_link.out_of_bounds, instance);
  }
}
inline
void
XlaDynamicComponentBase
::
BindInputLinks
(
const
LinkedEmbeddings
&
linked_embeddings
,
tensorflow
::
XlaCompiledCpuFunction
*
instance
)
const
{
for
(
size_t
i
=
0
;
i
<
input_links_
.
size
();
++
i
)
{
BindInputLink
(
linked_embeddings
.
embedding
(
i
),
linked_embeddings
.
is_out_of_bounds
(
i
),
input_links_
[
i
],
instance
);
}
}
// Binds each recurrent input to the previous step's output, or to a zero
// vector on the very first step.
inline void XlaDynamicComponentBase::BindInputRecurrences(
    size_t step_index, const NetworkStates &network_states,
    tensorflow::XlaCompiledCpuFunction *instance) const {
  for (const InputRecurrence &recurrence : input_recurrences_) {
    if (step_index == 0) {
      // The previous output is out-of-bounds, so feed a zero vector.  Recall
      // that |zeros_| was constructed to be large enough for any recurrence.
      BindInput(zeros_, recurrence.previous_output, instance);
      continue;
    }
    const Vector<float> previous(
        network_states.GetLayer(recurrence.handle).row(step_index - 1));
    BindInput(previous, recurrence.previous_output, instance);
  }
}
// Copies every cell output into the |step_index|'th row of its layer.
inline void XlaDynamicComponentBase::BindOutputLayers(
    size_t step_index, const NetworkStates &network_states,
    tensorflow::XlaCompiledCpuFunction *instance) const {
  for (const OutputLayer &output : output_layers_) {
    const auto row = network_states.GetLayer(output.handle).row(step_index);
    BindOutput(row, output.layer, instance);
  }
}
// Returns the per-session XLA instance, creating it on first use.  Only
// results, profiles, and temps are allocated; argument buffers are bound to
// runtime-managed memory via BindInput().
inline tensorflow::XlaCompiledCpuFunction &XlaDynamicComponentBase::GetInstance(
    SessionState *session_state) const {
  return session_state->extensions.Get(
      instance_handle_, *static_data_,
      tensorflow::XlaCompiledCpuFunction::AllocMode::
          RESULTS_PROFILES_AND_TEMPS_ONLY);
}
inline
void
XlaDynamicComponentBase
::
MaybeTrace
(
size_t
step_index
,
tensorflow
::
XlaCompiledCpuFunction
*
/*instance*/
,
ComponentTrace
*
component_trace
)
const
{
if
(
component_trace
==
nullptr
)
return
;
while
(
component_trace
->
step_trace_size
()
<=
step_index
)
{
component_trace
->
add_step_trace
();
}
// TODO(googleuser): Add once the JIT API supports this.
}
}
// namespace runtime
}
// namespace dragnn
}
// namespace syntaxnet
#endif // DRAGNN_RUNTIME_XLA_XLA_DYNAMIC_COMPONENT_BASE_H_
research/syntaxnet/dragnn/runtime/xla/xla_dynamic_component_test.cc
deleted
100644 → 0
View file @
a4bb31d0
// Copyright 2017 Google Inc. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
#include <functional>
#include <memory>
#include <string>
#include "dragnn/core/test/generic.h"
#include "dragnn/protos/cell_trace.pb.h"
#include "dragnn/protos/export.pb.h"
#include "dragnn/protos/spec.pb.h"
#include "dragnn/protos/trace.pb.h"
#include "dragnn/runtime/component.h"
#include "dragnn/runtime/extensions.h"
#include "dragnn/runtime/math/types.h"
#include "dragnn/runtime/network_states.h"
#include "dragnn/runtime/session_state.h"
#include "dragnn/runtime/test/fake_variable_store.h"
#include "dragnn/runtime/test/network_test_base.h"
#include "dragnn/runtime/type_keyed_set.h"
#include "dragnn/runtime/xla/xla_graph_utils.h"
#include "dragnn/runtime/xla/xla_spec_utils.h"
#include "syntaxnet/base.h"
#include <gmock/gmock.h>
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/framework/graph.pb.h"
#include "tensorflow/core/framework/node_def_builder.h"
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/lib/io/path.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/test.h"
namespace
syntaxnet
{
namespace
dragnn
{
namespace
runtime
{
namespace
{
using ::testing::_;
using ::testing::InSequence;
using ::testing::Invoke;
using ::testing::Return;

// Number of rows in the test embedding matrix (valid feature ID range).
constexpr int kVocabularySize = 123;

// Default dimension of the classification logits in the test graph.
constexpr int kLogitsDim = 11;

// Number of transition steps the mock compute session performs.
constexpr int kNumSteps = 50;
class
XlaDynamicComponentTest
:
public
NetworkTestBase
{
protected:
// Options for building a GraphDef file for tests.  By default, this
// specifies a working GraphDef file, but settings can be perturbed to
// trigger errors.
struct GraphDefOptions {
  GraphDefOptions() = default;

  // Dimension of the classification logits.
  int logits_dim = kLogitsDim;

  // Name of the variable containing the classification logits.
  string logits_name = "logits";

  // Type of the feature ID input.
  xla::PrimitiveType id_type = xla::S32;

  // Dimension of the feature ID input.
  int id_dim = 1;
};
// Builds and writes a simple frozen GraphDef file using the default options,
// which produce a valid frozen GraphDef.  Returns the path to the file.
static string WriteFrozenGraphDef() {
  const GraphDefOptions defaults;
  return WriteFrozenGraphDef(defaults);
}
// Maps an XLA primitive type to the equivalent TensorFlow dtype, or
// DT_INVALID for types the test does not use.
static tensorflow::DataType TensorFlowType(xla::PrimitiveType type) {
  switch (type) {
    case xla::S32:
      return tensorflow::DT_INT32;
    case xla::S64:
      return tensorflow::DT_INT64;
    case xla::F32:
      return tensorflow::DT_FLOAT;
    default:
      return tensorflow::DT_INVALID;
  }
}
static
string
WriteFrozenGraphDef
(
const
GraphDefOptions
&
options
)
{
CellSubgraphSpec
spec
;
tensorflow
::
GraphDef
graph
;
// A fixed feature ID input.
auto
*
input
=
spec
.
add_input
();
input
->
set_name
(
"fixed_channel_0_index_0_ids"
);
input
->
set_tensor
(
"cell/id:0"
);
input
->
set_type
(
CellSubgraphSpec
::
Input
::
TYPE_FEATURE
);
// The retrieved embedding row, as logits.
auto
*
output
=
spec
.
add_output
();
output
->
set_name
(
options
.
logits_name
);
output
->
set_tensor
(
"cell/lookup:0"
);
// Add CellSubgraphSpec node.
tensorflow
::
Tensor
spec_tensor
(
tensorflow
::
DT_STRING
,
tensorflow
::
TensorShape
({
1
}));
spec
.
SerializeToString
(
&
spec_tensor
.
vec
<
string
>
()(
0
));
tensorflow
::
TensorProto
spec_tensor_proto
;
spec_tensor
.
AsProtoField
(
&
spec_tensor_proto
);
TF_CHECK_OK
(
tensorflow
::
NodeDefBuilder
(
kFrozenCellSubgraphSpecNodeName
,
"Const"
)
.
Attr
(
"dtype"
,
tensorflow
::
DT_STRING
)
.
Attr
(
"value"
,
spec_tensor_proto
)
.
Attr
(
"shape"
,
tensorflow
::
TensorShape
({
1
}))
.
Finalize
(
graph
.
add_node
()));
// Fixed feature ID input placeholder node.
TF_CHECK_OK
(
tensorflow
::
NodeDefBuilder
(
"cell/id"
,
"Placeholder"
)
.
Attr
(
"dtype"
,
TensorFlowType
(
options
.
id_type
))
.
Attr
(
"shape"
,
tensorflow
::
TensorShape
({
options
.
id_dim
}))
.
Finalize
(
graph
.
add_node
()));
// An embedding matrix constant. Each embedding is filled with its index.
tensorflow
::
Tensor
embeddings
(
tensorflow
::
DT_FLOAT
,
tensorflow
::
TensorShape
({
kVocabularySize
,
options
.
logits_dim
}));
auto
raw_tensor
=
embeddings
.
tensor
<
float
,
2
>
();
for
(
int
row
=
0
;
row
<
kVocabularySize
;
++
row
)
{
for
(
int
column
=
0
;
column
<
options
.
logits_dim
;
++
column
)
{
raw_tensor
(
row
,
column
)
=
row
;
}
}
tensorflow
::
TensorProto
embeddings_proto
;
embeddings
.
AsProtoTensorContent
(
&
embeddings_proto
);
TF_CHECK_OK
(
tensorflow
::
NodeDefBuilder
(
"cell/embedding_matrix"
,
"Const"
)
.
Attr
(
"dtype"
,
tensorflow
::
DT_FLOAT
)
.
Attr
(
"value"
,
embeddings_proto
)
.
Finalize
(
graph
.
add_node
()));
// A Gather op that looks up the |id| in the |embeddings|, and returns the
// result in the |logits|.
TF_CHECK_OK
(
tensorflow
::
NodeDefBuilder
(
"cell/lookup"
,
"Gather"
)
.
Input
(
"cell/embedding_matrix"
,
0
,
tensorflow
::
DT_FLOAT
)
.
Input
(
"cell/id"
,
0
,
TensorFlowType
(
options
.
id_type
))
.
Attr
(
"validate_indices"
,
true
)
.
Finalize
(
graph
.
add_node
()));
const
string
path
=
tensorflow
::
io
::
JoinPath
(
tensorflow
::
testing
::
TmpDir
(),
"graph-frozen"
);
TF_CHECK_OK
(
SaveFrozenGraphDef
(
path
,
graph
));
return
path
;
}
// Creates a component, initializes it based on the |component_spec_text| and
// |flow_path|, and evaluates it. The |component_trace| is overwritten with
// traces, if non-null. On error, returns non-OK.
tensorflow
::
Status
Run
(
const
string
&
component_spec_text
=
""
,
const
string
&
flow_path
=
WriteFrozenGraphDef
(),
ComponentTrace
*
component_trace
=
nullptr
)
{
ComponentSpec
component_spec
;
CHECK
(
TextFormat
::
ParseFromString
(
component_spec_text
,
&
component_spec
));
if
(
!
component_spec
.
has_num_actions
())
{
component_spec
.
set_num_actions
(
kLogitsDim
);
}
component_spec
.
set_name
(
kTestComponentName
);
auto
*
fixed_feature
=
component_spec
.
add_fixed_feature
();
fixed_feature
->
set_embedding_dim
(
-
1
);
fixed_feature
->
set_size
(
1
);
TF_RETURN_IF_ERROR
(
AddFrozenGraphDefResource
(
flow_path
,
&
component_spec
));
AddComponent
(
kTestComponentName
);
TF_RETURN_IF_ERROR
(
Component
::
CreateOrError
(
"XlaDynamicComponent"
,
&
component_
));
TF_RETURN_IF_ERROR
(
component_
->
Initialize
(
component_spec
,
&
variable_store_
,
&
network_state_manager_
,
&
extension_manager_
));
network_states_
.
Reset
(
&
network_state_manager_
);
StartComponent
(
0
);
// XlaDynamicComponent will add steps
session_state_
.
extensions
.
Reset
(
&
extension_manager_
);
TF_RETURN_IF_ERROR
(
component_
->
Evaluate
(
&
session_state_
,
&
compute_session_
,
component_trace
));
return
tensorflow
::
Status
::
OK
();
}
std
::
unique_ptr
<
Component
>
component_
;
};
// Tests that XlaDynamicComponent fails if the spec uses attention.
TEST_F(XlaDynamicComponentTest, UnsupportedAttention) {
  const string kSpecText = "attention_component:'foo'";

  const tensorflow::Status status = Run(kSpecText);

  EXPECT_THAT(status, test::IsErrorWithSubstr("Attention is not supported"));
}
// Tests that XlaDynamicComponent rejects specs with embedded fixed features.
TEST_F(XlaDynamicComponentTest, InvalidFixedFeatureIsEmbedded) {
  const string kSpecText = "fixed_feature { embedding_dim:1 }";

  const tensorflow::Status status = Run(kSpecText);

  EXPECT_THAT(status, test::IsErrorWithSubstr(
                          "XLA requires non-embedded fixed features"));
}
// Tests that XlaDynamicComponent fails if the ComponentSpec has a fixed
// feature that does not appear in the graph.
TEST_F(XlaDynamicComponentTest, InvalidFixedFeatureNotInGraph) {
  // The spec declares a second fixed feature, but the graph only provides
  // inputs for the first; expect a lookup failure on the second channel.
  const string kExpectedError = tensorflow::strings::StrCat(
      "No XLA tensor named '", MakeXlaInputFixedFeatureIdName(1, 0), "'");

  EXPECT_THAT(Run("fixed_feature { embedding_dim:-1 size:1 }"),
              test::IsErrorWithSubstr(kExpectedError));
}
// Tests that XlaDynamicComponent rejects specs with multiplied linked
// features.
TEST_F(XlaDynamicComponentTest, InvalidLinkedFeatureIsMultiplied) {
  const string kSpecText = "linked_feature { embedding_dim:1 }";

  const tensorflow::Status status = Run(kSpecText);

  EXPECT_THAT(status, test::IsErrorWithSubstr(
                          "XLA requires non-multiplied linked features"));
}
// Tests that XlaDynamicComponent fails if the ComponentSpec has a linked
// feature that does not appear in the graph.
TEST_F(XlaDynamicComponentTest, InvalidLinkedFeatureNotInGraph) {
  const string kSpecText = tensorflow::strings::StrCat(
      "linked_feature { source_component:'", kTestComponentName,
      "' source_layer:'logits' embedding_dim:-1 size:1 }");
  const string kExpectedError = tensorflow::strings::StrCat(
      "No XLA tensor named '", MakeXlaInputLinkedActivationVectorName(0), "'");

  EXPECT_THAT(Run(kSpecText), test::IsErrorWithSubstr(kExpectedError));
}
// Tests that XlaDynamicComponent fails if the GraphDef file does not exist.
TEST_F(XlaDynamicComponentTest, InvalidPath) {
  const string kMissingPath = "/invalid/path";

  const tensorflow::Status status = Run("", kMissingPath);

  EXPECT_THAT(status, test::IsErrorWithSubstr("No such file or directory"));
}
// Tests that XlaDynamicComponent fails if the logits dimension does not
// match ComponentSpec.num_actions.
TEST_F(XlaDynamicComponentTest, WrongLogitsDimension) {
  GraphDefOptions graph_options;
  graph_options.logits_dim = kLogitsDim + 1;  // disagrees with num_actions
  const string graph_path = WriteFrozenGraphDef(graph_options);

  EXPECT_THAT(Run("", graph_path),
              test::IsErrorWithSubstr(
                  "Dimension mismatch between classification logits"));
}
// Tests that XlaDynamicComponent fails if there is no "logits" layer.
TEST_F(XlaDynamicComponentTest, WrongLogitsName) {
  GraphDefOptions graph_options;
  graph_options.logits_name = "not_logits";
  const string graph_path = WriteFrozenGraphDef(graph_options);

  EXPECT_THAT(Run("", graph_path),
              test::IsErrorWithSubstr("Unknown layer 'logits'"));
}
// Tests that XlaDynamicComponent fails to compile if one of the XLA
// tensors has the wrong type.
TEST_F(XlaDynamicComponentTest, FailToCompile) {
  GraphDefOptions graph_options;
  graph_options.id_type = xla::F32;  // feature IDs must be integral
  const string graph_path = WriteFrozenGraphDef(graph_options);

  EXPECT_THAT(Run("", graph_path),
              test::IsErrorWithSubstr(
                  "float is not in the list of allowed values"));
}
// Tests that XlaDynamicComponent fails if one of the XLA tensors is not
// vector-like.
TEST_F(XlaDynamicComponentTest, NotVectorLike) {
  GraphDefOptions graph_options;
  graph_options.id_dim = 2;  // IDs must be scalar-per-step
  const string graph_path = WriteFrozenGraphDef(graph_options);

  EXPECT_THAT(Run("", graph_path),
              test::IsErrorWithSubstr("XLA tensor has non-vector-like shape"));
}
// Tests that XlaDynamicComponent fails if AdvanceFromPrediction() fails.
TEST_F(XlaDynamicComponentTest, FailToAdvanceFromPrediction) {
  // The transition loop never terminates on its own; the single failing
  // AdvanceFromPrediction() call must abort the evaluation instead.
  EXPECT_CALL(compute_session_, IsTerminal(_)).WillRepeatedly(Return(false));
  EXPECT_CALL(compute_session_, AdvanceFromPrediction(_, _, _, _))
      .WillOnce(Return(false));
  // One feature extraction happens before the failed advance.
  EXPECT_CALL(compute_session_, GetInputFeatures(_, _, _, _, _))
      .WillOnce(Invoke(ExtractFeatures(0, {{10, 1.0}})));

  EXPECT_THAT(Run(), test::IsErrorWithSubstr(
                         "Error in ComputeSession::AdvanceFromPrediction()"));
}
// Tests that XlaDynamicComponent can run a simple non-deterministic frozen
// GraphDef.
TEST_F(XlaDynamicComponentTest, SimpleNonDeterministicFlow) {
  SetupTransitionLoop(kNumSteps);
  EXPECT_CALL(compute_session_, AdvanceFromPrediction(_, _, _, _))
      .Times(kNumSteps)
      .WillRepeatedly(Return(true));

  {
    // Extract a sequence of feature IDs equal to 2 * step_index.
    ASSERT_LE(2 * kNumSteps, kVocabularySize);
    InSequence scoped;
    for (int step_index = 0; step_index < kNumSteps; ++step_index) {
      EXPECT_CALL(compute_session_, GetInputFeatures(_, _, _, _, _))
          .WillOnce(Invoke(ExtractFeatures(0, {{2 * step_index, 1.0}})));
    }
  }

  TF_ASSERT_OK(Run());

  // The component should have produced one row of logits per step.
  const Matrix<float> logits(GetLayer(kTestComponentName, "logits"));
  ASSERT_EQ(logits.num_rows(), kNumSteps);
  ASSERT_EQ(logits.num_columns(), kLogitsDim);

  // Since each row of the embedding matrix is filled with its index, the logits
  // should be equal to the feature IDs.
  for (int step_index = 0; step_index < kNumSteps; ++step_index) {
    ExpectVector(logits.row(step_index), kLogitsDim, 2 * step_index);
  }
}
// Tests that XlaDynamicComponent can run a simple deterministic frozen
// GraphDef.
TEST_F(XlaDynamicComponentTest, SimpleDeterministicFlow) {
  SetupTransitionLoop(kNumSteps);
  // With num_actions:1 below, the component is deterministic and advances
  // from the oracle instead of from predictions.
  EXPECT_CALL(compute_session_, AdvanceFromOracle(kTestComponentName))
      .Times(kNumSteps);

  {
    // Extract a sequence of feature IDs equal to 2 * step_index.
    ASSERT_LE(2 * kNumSteps, kVocabularySize);
    InSequence scoped;
    for (int step_index = 0; step_index < kNumSteps; ++step_index) {
      EXPECT_CALL(compute_session_, GetInputFeatures(_, _, _, _, _))
          .WillOnce(Invoke(ExtractFeatures(0, {{2 * step_index, 1.0}})));
    }
  }

  // A 1-dimensional logits layer to match num_actions:1.
  GraphDefOptions options;
  options.logits_dim = 1;
  TF_ASSERT_OK(Run("num_actions:1", WriteFrozenGraphDef(options)));
}
// Tests that XlaDynamicComponent can run a simple frozen GraphDef with tracing
// enabled.
TEST_F(XlaDynamicComponentTest, SimpleFlowWithTracing) {
  SetupTransitionLoop(kNumSteps);
  EXPECT_CALL(compute_session_, AdvanceFromPrediction(_, _, _, _))
      .Times(kNumSteps)
      .WillRepeatedly(Return(true));

  {
    // Extract a sequence of feature IDs equal to 2 * step_index.
    ASSERT_LE(2 * kNumSteps, kVocabularySize);
    InSequence scoped;
    for (int step_index = 0; step_index < kNumSteps; ++step_index) {
      EXPECT_CALL(compute_session_, GetInputFeatures(_, _, _, _, _))
          .WillOnce(Invoke(ExtractFeatures(0, {{2 * step_index, 1.0}})));
    }
  }

  ComponentTrace component_trace;
  TF_ASSERT_OK(Run("", WriteFrozenGraphDef(), &component_trace));

  // Each step trace should have a cell trace from the XLA instance.
  ASSERT_EQ(component_trace.step_trace_size(), kNumSteps);
  for (const ComponentStepTrace &step_trace : component_trace.step_trace()) {
    // TODO(googleuser): Add once the JIT API supports this.
    EXPECT_EQ(step_trace.ExtensionSize(CellTrace::step_trace_extension), 0);
  }
}
}
// namespace
}
// namespace runtime
}
// namespace dragnn
}
// namespace syntaxnet
research/syntaxnet/dragnn/runtime/xla/xla_extract_config.cc
deleted
100644 → 0
View file @
a4bb31d0
// Copyright 2017 Google Inc. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
// Writes a file containing a text tf2xla::Config proto that is extracted
// from a frozen binary GraphDef file for a DRAGNN component.
//
// Usage: xla_extract_config input-graph-def output-config
// input-graph-def: input frozen tensorflow.GraphDef binary proto
// output-config: extracted tensorflow.tf2xla.Config text proto
#include <string.h>
#include "dragnn/protos/export.pb.h"
#include "dragnn/runtime/xla/xla_graph_utils.h"
#include "tensorflow/compiler/tf2xla/tf2xla.pb.h"
#include "tensorflow/core/framework/graph.pb.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/init_main.h"
#include "tensorflow/core/platform/logging.h"
namespace
syntaxnet
{
namespace
dragnn
{
namespace
runtime
{
namespace
{
// Writes the Config extracted from |input_graph_def| to |output_config|.
// On error, returns non-OK.
tensorflow::Status XlaExtractConfig(const char *input_graph_def,
                                    const char *output_config) {
  tensorflow::GraphDef frozen_graph;
  CellSubgraphSpec cell_subgraph_spec;
  tensorflow::tf2xla::Config xla_config;

  // Load the frozen graph, then derive the tf2xla Config (and the
  // CellSubgraphSpec embedded in the graph) from it.
  TF_RETURN_IF_ERROR(LoadFrozenGraphDef(input_graph_def, &frozen_graph));
  TF_RETURN_IF_ERROR(GetSpecAndMakeXlaConfig(frozen_graph, &cell_subgraph_spec,
                                             &xla_config));

  // Dump the extracted config as a text proto.
  return WriteTextProto(tensorflow::Env::Default(), output_config, xla_config);
}
}
// namespace
}
// namespace runtime
}
// namespace dragnn
}
// namespace syntaxnet
// Entry point: validates the two path arguments and extracts the config.
int main(int argc, char **argv) {
  tensorflow::port::InitMain(argv[0], &argc, &argv);

  const bool args_ok =
      argc == 3 && strlen(argv[1]) != 0 && strlen(argv[2]) != 0;
  if (!args_ok) {
    LOG(FATAL)
        << "Usage: xla_extract_config input-graph-def output-config\n"
           " input-graph-def: input frozen tensorflow.GraphDef binary proto\n"
           " output-config: extracted tensorflow.tf2xla.Config text proto\n";
  }

  TF_CHECK_OK(syntaxnet::dragnn::runtime::XlaExtractConfig(argv[1], argv[2]));
  return 0;
}
research/syntaxnet/dragnn/runtime/xla/xla_extract_names_from_specs.cc
deleted
100644 → 0
View file @
a4bb31d0
// Copyright 2017 Google Inc. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
// Writes a Bazel file containing a definition for XLA_AOT_COMPONENTS. The
// value is an array; each element is an array of strings containing information
// needed to build the XLA AOT library for a graph, and the DRAGNN component
// that uses it.
//
// This file is loaded and then used by the dragnn_xla_aot_components() build
// rule (see xla_build_defs.bzl). Its contents are verified to be current by the
// dragnn_xla_aot_bazel_test() build rule, which runs this program.
//
// This program processes a set of MasterSpecs; the benefits for processing
// a set of MasterSpecs together are:
// - only a single build rule is necessary for adding component libraries;
// - duplicates of model/components across MasterSpecs are flagged as errors.
//
// Usage: xla_extract_names_from_specs graph-base [master-spec-path]+ bazel-path
// graph-base: base path to remove on GraphDefs in MasterSpecs
// master-specs: DRAGNN model MasterSpecs (includes base-path)
// bazel-path: Bazel definition output file
#include <string>
#include <vector>
#include "dragnn/runtime/xla/xla_spec_build_utils.h"
#include "syntaxnet/base.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/init_main.h"
#include "tensorflow/core/platform/logging.h"
int
main
(
int
argc
,
char
**
argv
)
{
tensorflow
::
port
::
InitMain
(
argv
[
0
],
&
argc
,
&
argv
);
if
(
argc
<
5
)
{
LOG
(
FATAL
)
<<
"Usage: xla_extract_names_from_specs"
" graph-base [master-spec-path]+ bazel-path
\n
"
" graph-base: base path to remove on GraphDefs in MasterSpecs
\n
"
" master-specs: DRAGNN model MasterSpecs (includes base-path)
\n
"
" bazel-path: Bazel definition output file
\n
"
;
}
const
char
*
base_path
=
argv
[
1
];
std
::
vector
<
string
>
master_spec_paths
;
for
(
int
i
=
2
;
i
<
argc
-
1
;
i
++
)
{
master_spec_paths
.
push_back
(
argv
[
i
]);
}
const
string
&
bazel_path
=
argv
[
argc
-
1
];
string
bazel_def
;
tensorflow
::
strings
::
StrAppend
(
&
bazel_def
,
"
\"\"\"
Generated by xla_extract_names_from_specs. "
"Do not edit.
\"\"\"\n\n
"
);
TF_CHECK_OK
(
syntaxnet
::
dragnn
::
runtime
::
MasterSpecsToBazelDef
(
"XLA_AOT_COMPONENTS"
,
base_path
,
master_spec_paths
,
&
bazel_def
));
TF_CHECK_OK
(
tensorflow
::
WriteStringToFile
(
tensorflow
::
Env
::
Default
(),
bazel_path
,
bazel_def
));
return
0
;
}
research/syntaxnet/dragnn/runtime/xla/xla_graph_utils.cc
deleted
100644 → 0
View file @
a4bb31d0
// Copyright 2017 Google Inc. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
#include "dragnn/runtime/xla/xla_graph_utils.h"
#include <cstddef>
#include <map>
#include <set>
#include <utility>
#include "dragnn/runtime/xla/xla_spec_utils.h"
#include "syntaxnet/base.h"
#include "tensorflow/compiler/tf2xla/tf2xla.pb.h"
#include "tensorflow/core/framework/attr_value.pb.h"
#include "tensorflow/core/framework/graph.pb.h"
#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/framework/node_def_util.h"
#include "tensorflow/core/framework/tensor_shape.pb.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/logging.h"
namespace
syntaxnet
{
namespace
dragnn
{
namespace
runtime
{
const
char
*
const
kFrozenCellSubgraphSpecNodeName
=
"CellSubgraphSpec"
;
namespace
{
// Fills the TensorId fields given |tensor_name|. On error, returns non-OK.
tensorflow::Status FillXlaTensorId(const string &tensor_name,
                                   tensorflow::tf2xla::TensorId *id) {
  // Split "node:index" into its two parts before populating |id|.
  string node_name;
  uint32 output_index;
  TF_RETURN_IF_ERROR(ParseTensorName(tensor_name, &node_name, &output_index));

  id->set_node_name(node_name);
  id->set_output_index(output_index);
  return tensorflow::Status::OK();
}
// Loads the |shape| proto from the placeholder |node|. On error, returns
// non-OK.
tensorflow::Status GetPlaceholderShape(const tensorflow::NodeDef &node,
                                       tensorflow::TensorShapeProto *shape_proto) {
  // Only Placeholder nodes carry the "shape" attr we expect here.
  if (node.op() == "Placeholder") {
    return tensorflow::GetNodeAttr(node, "shape", shape_proto);
  }
  return tensorflow::errors::InvalidArgument(
      "Input node '", node.name(), "' is not a Placeholder");
}
}
// namespace
// Reads |frozen_graph_def_path| into |graph_def|, choosing text vs binary
// proto format from the file extension.  On error, returns non-OK.
tensorflow::Status LoadFrozenGraphDef(const string &frozen_graph_def_path,
                                      tensorflow::GraphDef *graph_def) {
  tensorflow::Env *const env = tensorflow::Env::Default();
  // ".pbtxt" files are text protos; everything else is assumed binary.
  const bool is_text_format =
      tensorflow::str_util::EndsWith(frozen_graph_def_path, ".pbtxt");
  return is_text_format
             ? tensorflow::ReadTextProto(env, frozen_graph_def_path, graph_def)
             : tensorflow::ReadBinaryProto(env, frozen_graph_def_path,
                                           graph_def);
}
// Serializes |graph_def| deterministically and writes the bytes to
// |frozen_graph_def_path|.  On error, returns non-OK.
tensorflow::Status SaveFrozenGraphDef(const string &frozen_graph_def_path,
                                      const tensorflow::GraphDef &graph_def) {
  const std::size_t size = graph_def.ByteSizeLong();
  string data(size, '\0');
  if (size > 0) {
    // Serialize into |data| in place.  Deterministic mode keeps the output
    // stable across runs (e.g., fixes attr map ordering; see header comment).
    tensorflow::protobuf::io::ArrayOutputStream array_stream(&data[0], size);
    tensorflow::protobuf::io::CodedOutputStream output_stream(&array_stream);
    output_stream.SetSerializationDeterministic(true);
    graph_def.SerializeWithCachedSizes(&output_stream);
    // A short or failed write means the serialization did not fill the
    // pre-sized buffer as expected.
    if (output_stream.HadError() || size != output_stream.ByteCount()) {
      return tensorflow::errors::InvalidArgument("Cannot serialize GraphDef");
    }
  }
  return tensorflow::WriteStringToFile(tensorflow::Env::Default(),
                                       frozen_graph_def_path, data);
}
// Splits |tensor_name| of the form "name" or "name:index" into |name| and
// |index| (index defaults to 0 when absent).  On error, changes nothing and
// returns non-OK.
tensorflow::Status ParseTensorName(const string &tensor_name, string *name,
                                   uint32 *index) {
  // Previously an empty |tensor_name| silently succeeded with an empty node
  // name, which surfaced later as a confusing "Cannot find node" error.
  if (tensor_name.empty()) {
    return tensorflow::errors::InvalidArgument("Cannot parse empty tensor name");
  }
  if (tensor_name[0] == '^') {
    return tensorflow::errors::InvalidArgument(
        "Cannot parse name of control input '", tensor_name, "'");
  }
  const auto colon_index = tensor_name.rfind(':');
  if (colon_index == string::npos) {
    // no colon; assume 0
    *index = 0;
  } else {
    const string output_str = tensor_name.substr(colon_index + 1);
    if (!tensorflow::strings::safe_strtou32(output_str, index)) {
      return tensorflow::errors::InvalidArgument("Malformed tensor name ",
                                                 tensor_name);
    }
  }

  // NB: If |colon_index| is string::npos, takes the whole string as desired.
  *name = tensor_name.substr(0, colon_index);
  return tensorflow::Status::OK();
}
// Extracts the CellSubgraphSpec embedded in |graph_def| and derives the tf2xla
// |xla_config| (feeds from spec inputs, fetches from the first occurrence of
// each distinct output tensor).  On error, returns non-OK.
tensorflow::Status GetSpecAndMakeXlaConfig(
    const tensorflow::GraphDef &graph_def, CellSubgraphSpec *cell_subgraph_spec,
    tensorflow::tf2xla::Config *xla_config) {
  // Maps the node name to its corresponding node in the GraphDef.
  std::map<string, const tensorflow::NodeDef *> node_name_map;
  for (const tensorflow::NodeDef &node : graph_def.node()) {
    node_name_map[node.name()] = &node;
  }

  // Looks for a node called |name| in |graph_def|. If present, returns OK
  // and fills in |*node|, otherwise returns non-OK.
  auto lookup_node = [&](const string &name,
                         const tensorflow::NodeDef **node) {
    const auto it = node_name_map.find(name);
    if (it == node_name_map.end()) {
      return tensorflow::errors::NotFound("Cannot find node ", name);
    }
    *node = it->second;
    return tensorflow::Status::OK();
  };

  // Retrieves the CellSubgraphSpec from the frozen graph.  Use the exported
  // node-name constant instead of repeating the "CellSubgraphSpec" literal,
  // keeping this lookup consistent with writers of the frozen graph.
  const tensorflow::NodeDef *spec_node = nullptr;
  TF_RETURN_IF_ERROR(lookup_node(kFrozenCellSubgraphSpecNodeName, &spec_node));
  const auto value_it = spec_node->attr().find("value");
  if (value_it == spec_node->attr().end()) {
    return tensorflow::errors::NotFound("Cannot find CellSubgraphSpec value");
  }
  // Guard the repeated-field access: string_val(0) on an empty tensor would
  // CHECK-crash instead of returning a proper error status.
  if (value_it->second.tensor().string_val_size() == 0) {
    return tensorflow::errors::InvalidArgument(
        "CellSubgraphSpec value tensor is empty");
  }
  if (!cell_subgraph_spec->ParseFromString(
          value_it->second.tensor().string_val(0))) {
    return tensorflow::errors::InvalidArgument(
        "Failed to parse CellSubgraphSpec");
  }
  VLOG(1) << "CellSubgraphSpec: " << cell_subgraph_spec->DebugString();

  // Builds the Config feeds.  Each input feed takes its shape from the
  // Placeholder node it refers to.
  for (const auto &input : cell_subgraph_spec->input()) {
    auto *feed = xla_config->add_feed();
    feed->set_name(MakeXlaInputLayerName(input.name()));
    TF_RETURN_IF_ERROR(FillXlaTensorId(input.tensor(), feed->mutable_id()));
    const tensorflow::NodeDef *input_node;
    TF_RETURN_IF_ERROR(lookup_node(feed->id().node_name(), &input_node));
    TF_RETURN_IF_ERROR(GetPlaceholderShape(*input_node, feed->mutable_shape()));
  }

  // Builds the Config fetches and alias map.
  std::set<string> output_tensors;
  for (const auto &output : cell_subgraph_spec->output()) {
    if (output_tensors.insert(output.tensor()).second) {
      // The first time a tensor is encountered, this adds a fetch along with
      // its name. The remaining names associated with the same tensor (aliases)
      // are handled by InitializeOutputLayers.
      auto *fetch = xla_config->add_fetch();
      fetch->set_name(MakeXlaOutputLayerName(output.name()));
      TF_RETURN_IF_ERROR(FillXlaTensorId(output.tensor(), fetch->mutable_id()));
    }
  }
  VLOG(1) << "Config: " << xla_config->DebugString();
  return tensorflow::Status::OK();
}
}
// namespace runtime
}
// namespace dragnn
}
// namespace syntaxnet
research/syntaxnet/dragnn/runtime/xla/xla_graph_utils.h
deleted
100644 → 0
View file @
a4bb31d0
// Copyright 2017 Google Inc. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
// Utils for working with specifications of XLA-based DRAGNN runtime models.
#ifndef DRAGNN_RUNTIME_XLA_XLA_GRAPH_UTILS_H_
#define DRAGNN_RUNTIME_XLA_XLA_GRAPH_UTILS_H_
#include <string>
#include "dragnn/protos/export.pb.h"
#include "syntaxnet/base.h"
#include "tensorflow/compiler/tf2xla/tf2xla.pb.h"
#include "tensorflow/core/framework/graph.pb.h"
#include "tensorflow/core/lib/core/status.h"
namespace
syntaxnet
{
namespace
dragnn
{
namespace
runtime
{
// The name of the node in the frozen GraphSpec (for a particular component)
// that contains the serialized CellSubgraphSpec.
extern const char *const kFrozenCellSubgraphSpecNodeName;

// Loads a GraphDef file from the |frozen_graph_def_path| into the |graph_def|.
// Assumes binary proto unless |frozen_graph_def_path| ends with ".pbtxt", in
// which case it assumes text proto format. On error, returns non-OK.
tensorflow::Status LoadFrozenGraphDef(const string &frozen_graph_def_path,
                                      tensorflow::GraphDef *graph_def);

// Saves a GraphDef |graph_def| in the file |frozen_graph_def_path|. Uses
// deterministic serialization to avoid churn due to attr map order.
// Always writes in binary format. On error, returns non-OK.
tensorflow::Status SaveFrozenGraphDef(const string &frozen_graph_def_path,
                                      const tensorflow::GraphDef &graph_def);

// Fills in |name| and |index| given the |tensor_name| of the form
// "name" or "name:index". On error, changes nothing and returns non-OK.
tensorflow::Status ParseTensorName(const string &tensor_name, string *name,
                                   uint32 *index);

// Given a frozen |graph_def|, extracts the |cell_subgraph_spec| stored within
// it, and generates the |xla_config| proto. Whenever an output tensor is
// aliased, the output in |xla_config| is taken from the first occurrence of
// the tensor in |cell_subgraph_spec| (aliases are resolved in the XLA
// component in InitializeOutputLayers). On error, returns non-OK.
tensorflow::Status GetSpecAndMakeXlaConfig(
    const tensorflow::GraphDef &graph_def, CellSubgraphSpec *cell_subgraph_spec,
    tensorflow::tf2xla::Config *xla_config);
}
// namespace runtime
}
// namespace dragnn
}
// namespace syntaxnet
#endif // DRAGNN_RUNTIME_XLA_XLA_GRAPH_UTILS_H_
Prev
1
…
11
12
13
14
15
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment