Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
ModelZoo
ResNet50_tensorflow
Commits
edea2b67
Commit
edea2b67
authored
May 11, 2018
by
Terry Koo
Browse files
Remove runtime because reasons.
parent
a4bb31d0
Changes
291
Show whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
0 additions
and
3528 deletions
+0
-3528
research/syntaxnet/dragnn/runtime/variable_store_test.cc
research/syntaxnet/dragnn/runtime/variable_store_test.cc
+0
-234
research/syntaxnet/dragnn/runtime/variable_store_wrappers.cc
research/syntaxnet/dragnn/runtime/variable_store_wrappers.cc
+0
-170
research/syntaxnet/dragnn/runtime/variable_store_wrappers.h
research/syntaxnet/dragnn/runtime/variable_store_wrappers.h
+0
-143
research/syntaxnet/dragnn/runtime/variable_store_wrappers_test.cc
.../syntaxnet/dragnn/runtime/variable_store_wrappers_test.cc
+0
-270
research/syntaxnet/dragnn/runtime/xla/BUILD
research/syntaxnet/dragnn/runtime/xla/BUILD
+0
-362
research/syntaxnet/dragnn/runtime/xla/sequence_xla_dynamic_component_mixin.h
...dragnn/runtime/xla/sequence_xla_dynamic_component_mixin.h
+0
-186
research/syntaxnet/dragnn/runtime/xla/sequence_xla_dynamic_component_mixin_test.cc
.../runtime/xla/sequence_xla_dynamic_component_mixin_test.cc
+0
-390
research/syntaxnet/dragnn/runtime/xla/testdata/simple-component-spec
...ntaxnet/dragnn/runtime/xla/testdata/simple-component-spec
+0
-24
research/syntaxnet/dragnn/runtime/xla/testdata/simple-config.pbtxt
...syntaxnet/dragnn/runtime/xla/testdata/simple-config.pbtxt
+0
-17
research/syntaxnet/dragnn/runtime/xla/testdata/simple-graph.pbtxt
.../syntaxnet/dragnn/runtime/xla/testdata/simple-graph.pbtxt
+0
-105
research/syntaxnet/dragnn/runtime/xla/testdata/xla_compilation_output/master-spec
...n/runtime/xla/testdata/xla_compilation_output/master-spec
+0
-160
research/syntaxnet/dragnn/runtime/xla/testdata/xla_compilation_output/master-spec-aot
...ntime/xla/testdata/xla_compilation_output/master-spec-aot
+0
-247
research/syntaxnet/dragnn/runtime/xla/testdata/xla_compilation_output/rnn-frozen
...nn/runtime/xla/testdata/xla_compilation_output/rnn-frozen
+0
-0
research/syntaxnet/dragnn/runtime/xla/testdata/xla_compilation_output/tagger-frozen
...runtime/xla/testdata/xla_compilation_output/tagger-frozen
+0
-0
research/syntaxnet/dragnn/runtime/xla/xla_aot_dynamic_component.h
.../syntaxnet/dragnn/runtime/xla/xla_aot_dynamic_component.h
+0
-125
research/syntaxnet/dragnn/runtime/xla/xla_aot_dynamic_component_test.cc
...xnet/dragnn/runtime/xla/xla_aot_dynamic_component_test.cc
+0
-217
research/syntaxnet/dragnn/runtime/xla/xla_build_defs.bzl
research/syntaxnet/dragnn/runtime/xla/xla_build_defs.bzl
+0
-308
research/syntaxnet/dragnn/runtime/xla/xla_cell_converter.cc
research/syntaxnet/dragnn/runtime/xla/xla_cell_converter.cc
+0
-304
research/syntaxnet/dragnn/runtime/xla/xla_cell_converter.h
research/syntaxnet/dragnn/runtime/xla/xla_cell_converter.h
+0
-152
research/syntaxnet/dragnn/runtime/xla/xla_cell_converter_test.cc
...h/syntaxnet/dragnn/runtime/xla/xla_cell_converter_test.cc
+0
-114
No files found.
Too many changes to show.
To preserve performance only
291 of 291+
files are displayed.
Plain diff
Email patch
research/syntaxnet/dragnn/runtime/variable_store_test.cc
deleted
100644 → 0
View file @
a4bb31d0
// Copyright 2017 Google Inc. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
#include "dragnn/runtime/variable_store.h"
#include <stddef.h>
#include <algorithm>
#include <map>
#include <memory>
#include <string>
#include <vector>
#include "dragnn/core/test/generic.h"
#include "dragnn/protos/runtime.pb.h"
#include "dragnn/runtime/test/fake_variable_store.h"
#include "dragnn/runtime/test/helpers.h"
#include "syntaxnet/base.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/lib/io/path.h"
#include "tensorflow/core/platform/test.h"
namespace
syntaxnet
{
namespace
dragnn
{
namespace
runtime
{
namespace
{
// Tests that VariableStore::Lookup() fails to retrieve a vector if the
// underlying area does not have exactly one sub-view.
TEST
(
VariableStoreTest
,
LookupEmptyVector
)
{
SimpleFakeVariableStore
store
;
Vector
<
uint32
>
vector32
;
store
.
MockLookup
<
uint32
>
({
0
},
{});
EXPECT_THAT
(
store
.
Lookup
(
"empty"
,
&
vector32
),
test
::
IsErrorWithSubstr
(
"Vector variable 'empty' should have 1 sub-view but has 0"
));
}
TEST
(
VariableStoreTest
,
LookupVectorWrongDimensions
)
{
SimpleFakeVariableStore
store
;
Vector
<
float
>
vector
;
// Dimensions should indicate number of logical elements (1), not bytes (4).
store
.
MockLookup
<
char
>
({
4
},
{{
'1'
,
'2'
,
'3'
,
'4'
}});
EXPECT_THAT
(
store
.
Lookup
(
"wrongdim_1"
,
&
vector
),
test
::
IsErrorWithSubstr
(
"Vector size (1) disagrees with dimensions[0] (4)"
));
// Missing dimensions raise errors.
store
.
MockLookup
<
char
>
({},
{{
'1'
,
'2'
,
'3'
,
'4'
}});
EXPECT_THAT
(
store
.
Lookup
(
"nodims"
,
&
vector
),
test
::
IsErrorWithSubstr
(
"Expected 1 dimensions, got 0"
));
}
// Tests that VariableStore::Lookup() fails to retrieve a vector if the
// underlying area is not divisible into elements of sizeof(T) bytes.
TEST
(
VariableStoreTest
,
LookupVector
)
{
SimpleFakeVariableStore
store
;
Vector
<
uint32
>
vector32
;
Vector
<
uint64
>
vector64
;
store
.
MockLookup
<
char
>
({
6
},
{{
'1'
,
'2'
,
'3'
,
'4'
,
'5'
,
'6'
}});
EXPECT_THAT
(
store
.
Lookup
(
"123456"
,
&
vector32
),
test
::
IsErrorWithSubstr
(
"Vector variable '123456' does not divide into elements of size 4"
));
store
.
MockLookup
<
char
>
({
6
},
{{
'1'
,
'2'
,
'3'
,
'4'
,
'5'
,
'6'
}});
EXPECT_THAT
(
store
.
Lookup
(
"123456"
,
&
vector64
),
test
::
IsErrorWithSubstr
(
"Vector variable '123456' does not divide into elements of size 8"
));
store
.
MockLookup
<
char
>
({
2
},
{{
'1'
,
'2'
,
'3'
,
'4'
,
'5'
,
'6'
,
'7'
,
'8'
}});
TF_EXPECT_OK
(
store
.
Lookup
(
"12345678"
,
&
vector32
));
EXPECT_EQ
(
vector32
.
size
(),
2
);
const
string
bytes32
(
reinterpret_cast
<
const
char
*>
(
vector32
.
data
()),
8
);
EXPECT_EQ
(
bytes32
,
"12345678"
);
store
.
MockLookup
<
uint64
>
({
1
},
{{
7777
}});
TF_EXPECT_OK
(
store
.
Lookup
(
"12345678"
,
&
vector64
));
EXPECT_EQ
(
vector64
.
size
(),
1
);
EXPECT_EQ
(
vector64
[
0
],
7777
);
}
// Tests that the VariableStore fails to lookup a matrix if its dimensions are
// mismatched.
TEST
(
VariableStoreTest
,
LookupMatrixWrongDimensions
)
{
SimpleFakeVariableStore
store
;
Matrix
<
float
>
matrix
;
// Missing dimensions raise errors.
store
.
MockLookup
<
char
>
({},
{{
'1'
,
'2'
,
'3'
,
'4'
}});
EXPECT_THAT
(
store
.
Lookup
(
"nodims"
,
&
matrix
),
test
::
IsErrorWithSubstr
(
"Expected 2 dimensions, got 0"
));
// Wrong number of columns returned.
store
.
MockLookup
<
char
>
({
1
,
2
},
{{
'1'
,
'2'
,
'3'
,
'4'
}});
EXPECT_THAT
(
store
.
Lookup
(
"wrongcols"
,
&
matrix
),
test
::
IsErrorWithSubstr
(
"Matrix columns (1) disagrees with dimensions[1] (2)"
));
// Wrong number of rows returned.
store
.
MockLookup
<
char
>
({
3
,
1
},
{{
'1'
,
'2'
,
'3'
,
'4'
}});
EXPECT_THAT
(
store
.
Lookup
(
"wrongrows"
,
&
matrix
),
test
::
IsErrorWithSubstr
(
"Matrix rows (1) disagrees with dimensions[0] (3)"
));
}
// Tests that VariableStore::Lookup() fails to retrieve a row-major matrix if
// the underlying area is not divisible into elements of sizeof(T) bytes.
TEST
(
VariableStoreTest
,
LookupRowMajorMatrix
)
{
SimpleFakeVariableStore
store
;
Matrix
<
uint32
>
matrix32
;
Matrix
<
uint64
>
matrix64
;
store
.
MockLookup
<
char
>
(
{
6
,
2
},
ReplicateRows
<
char
>
({
'1'
,
'2'
,
'3'
,
'4'
,
'5'
,
'6'
},
6
));
EXPECT_THAT
(
store
.
Lookup
(
"123456"
,
&
matrix32
),
test
::
IsErrorWithSubstr
(
"Matrix variable '123456' does not divide into elements of size 4"
));
store
.
MockLookup
<
char
>
(
{
6
,
2
},
ReplicateRows
<
char
>
({
'1'
,
'2'
,
'3'
,
'4'
,
'5'
,
'6'
},
6
));
EXPECT_THAT
(
store
.
Lookup
(
"123456"
,
&
matrix64
),
test
::
IsErrorWithSubstr
(
"Matrix variable '123456' does not divide into elements of size 8"
));
store
.
MockLookup
<
char
>
(
{
8
,
2
},
ReplicateRows
<
char
>
({
'1'
,
'2'
,
'3'
,
'4'
,
'5'
,
'6'
,
'7'
,
'8'
},
8
));
TF_EXPECT_OK
(
store
.
Lookup
(
"12345678"
,
&
matrix32
));
EXPECT_EQ
(
matrix32
.
num_rows
(),
8
);
EXPECT_EQ
(
matrix32
.
num_columns
(),
2
);
for
(
size_t
i
=
0
;
i
<
matrix32
.
num_rows
();
++
i
)
{
const
string
bytes32
(
reinterpret_cast
<
const
char
*>
(
matrix32
.
row
(
i
).
data
()),
8
);
EXPECT_EQ
(
bytes32
,
"12345678"
);
}
store
.
MockLookup
({
8
,
1
},
ReplicateRows
<
uint64
>
({
7777
},
8
));
TF_EXPECT_OK
(
store
.
Lookup
(
"12345678"
,
&
matrix64
));
EXPECT_EQ
(
matrix64
.
num_rows
(),
8
);
EXPECT_EQ
(
matrix64
.
num_columns
(),
1
);
for
(
size_t
i
=
0
;
i
<
matrix64
.
num_rows
();
++
i
)
{
EXPECT_EQ
(
matrix64
.
row
(
i
)[
0
],
7777
);
}
}
// Tests that the VariableStore fails to lookup a blocked matrix if its
// dimensions are mismatched.
TEST
(
VariableStoreTest
,
BlockedLookupWrongDimensions
)
{
SimpleFakeVariableStore
store
;
BlockedMatrix
<
float
>
matrix
;
// Missing dimensions raise errors.
store
.
MockLookup
<
char
>
({},
{{
'1'
,
'2'
,
'3'
,
'4'
}});
EXPECT_THAT
(
store
.
Lookup
(
"nodims"
,
&
matrix
),
test
::
IsErrorWithSubstr
(
"Expected 3 dimensions, got 0"
));
// Wrong number of columns returned.
store
.
MockLookup
<
char
>
({
1
,
2
,
1
},
{{
'1'
,
'2'
,
'3'
,
'4'
}});
EXPECT_THAT
(
store
.
Lookup
(
"wrongcols"
,
&
matrix
),
test
::
IsErrorWithSubstr
(
"Rows * cols (2) != area view size (1)"
));
// Wrong number of rows returned.
store
.
MockLookup
<
char
>
({
3
,
1
,
1
},
{{
'1'
,
'2'
,
'3'
,
'4'
}});
EXPECT_THAT
(
store
.
Lookup
(
"wrongrows"
,
&
matrix
),
test
::
IsErrorWithSubstr
(
"Rows * cols (3) != area view size (1)"
));
// Wrong area view size.
store
.
MockLookup
<
float
>
({
1
,
1
,
1
},
{{
1.0
f
,
2.0
f
}});
EXPECT_THAT
(
store
.
Lookup
(
"wrongviewsize"
,
&
matrix
),
test
::
IsErrorWithSubstr
(
"Area view size (8) doesn't correspond to block "
"size (1) times data type size (4)"
));
}
TEST
(
VariableStoreTest
,
DoubleBlockedLookup
)
{
// BlockedMatrix::Reset() will fail if there is any alignment padding, so we
// construct an appropriate block size.
static_assert
(
internal
::
kAlignmentBytes
%
sizeof
(
double
)
==
0
,
"Alignment requirement is too small"
);
constexpr
int
kBlockSize
=
internal
::
kAlignmentBytes
/
sizeof
(
double
);
constexpr
int
kNumSubMatrices
=
3
;
constexpr
int
kNumRows
=
10
;
constexpr
int
kNumColumns
=
kNumSubMatrices
*
kBlockSize
;
constexpr
int
kNumBlocks
=
kNumSubMatrices
*
kNumRows
;
// Fill a data matrix with consecutively increasing values.
std
::
vector
<
std
::
vector
<
double
>>
data
;
double
value
=
0.0
;
for
(
int
block
=
0
;
block
<
kNumBlocks
;
++
block
)
{
data
.
emplace_back
();
for
(
int
i
=
0
;
i
<
kBlockSize
;
++
i
)
data
.
back
().
push_back
(
value
++
);
}
SimpleFakeVariableStore
store
;
BlockedMatrix
<
double
>
matrix
;
store
.
MockLookup
<
double
>
({
kNumRows
,
kNumColumns
,
kBlockSize
},
data
);
TF_EXPECT_OK
(
store
.
Lookup
(
"small_matrix_lookup"
,
&
matrix
));
EXPECT_EQ
(
matrix
.
num_rows
(),
kNumRows
);
EXPECT_EQ
(
matrix
.
num_columns
(),
kNumColumns
);
EXPECT_EQ
(
matrix
.
block_size
(),
kBlockSize
);
EXPECT_EQ
(
matrix
.
num_vectors
(),
kNumBlocks
);
double
expected
=
0.0
;
for
(
int
i
=
0
;
i
<
kNumBlocks
;
++
i
)
{
for
(
double
value
:
matrix
.
vector
(
i
))
EXPECT_EQ
(
value
,
expected
++
);
}
}
}
// namespace
}
// namespace runtime
}
// namespace dragnn
}
// namespace syntaxnet
research/syntaxnet/dragnn/runtime/variable_store_wrappers.cc
deleted
100644 → 0
View file @
a4bb31d0
// Copyright 2017 Google Inc. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
#include "dragnn/runtime/variable_store_wrappers.h"
#include <algorithm>
#include <tuple>
#include <utility>
#include <vector>
#include "dragnn/runtime/flexible_matrix_kernel.h"
#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/logging.h"
namespace
syntaxnet
{
namespace
dragnn
{
namespace
runtime
{
namespace
{
// Returns the name of the averaged version of the variable named |name|.
string
GetAveragedName
(
const
string
&
name
)
{
return
tensorflow
::
strings
::
StrCat
(
name
,
"/ExponentialMovingAverage"
);
}
// Rounds a number, |rows|, up to a multiple of |multiple|. For example,
// PadRows(6, 4) will return 8, because 8 is the nearest number after 6 that is
// divisible by 4. This method requires that |multiple| be positive. It is used
// for pre-calculating the dimension of a blocked matrix, instead of having to
// read the entire matrix.
int
PadRows
(
int
rows
,
int
multiple
)
{
DCHECK_GT
(
multiple
,
0
);
return
multiple
*
((
rows
+
multiple
-
1
)
/
multiple
);
}
// Calculates effective speed of a blocked matrix kernel. Blocked kernels may do
// a bit more calculation than necessary (since each AVX/SSE register contains
// multiple values), so their effective speed is less in those cases.
float
EffectiveGflops
(
int
rows
,
int
block_dim
,
float
base_gflops
)
{
float
padded_rows
=
PadRows
(
rows
,
block_dim
);
return
(
rows
/
padded_rows
)
*
base_gflops
;
}
}
// namespace
TryAveragedVariableStoreWrapper
::
TryAveragedVariableStoreWrapper
(
std
::
unique_ptr
<
VariableStore
>
variable_store
,
bool
allow_fallback
)
:
wrapped_variable_store_
(
std
::
move
(
variable_store
)),
allow_fallback_
(
allow_fallback
)
{}
tensorflow
::
Status
TryAveragedVariableStoreWrapper
::
Lookup
(
const
string
&
name
,
VariableSpec
::
Format
format
,
std
::
vector
<
size_t
>
*
dimensions
,
AlignedArea
*
area
)
{
tensorflow
::
Status
status
=
wrapped_variable_store_
->
Lookup
(
GetAveragedName
(
name
),
format
,
dimensions
,
area
);
if
(
status
.
ok
())
{
LOG
(
INFO
)
<<
"Using averaged variable: "
<<
GetAveragedName
(
name
);
return
status
;
}
if
(
allow_fallback_
)
{
LOG
(
INFO
)
<<
"Falling back to non-averaged variable: "
<<
name
;
return
wrapped_variable_store_
->
Lookup
(
name
,
format
,
dimensions
,
area
);
}
return
tensorflow
::
errors
::
InvalidArgument
(
"Failed to retrieve averaged variable '"
,
GetAveragedName
(
name
),
"' for variable '"
,
name
,
"': "
,
status
.
error_message
());
}
tensorflow
::
Status
TryAveragedVariableStoreWrapper
::
Close
()
{
return
wrapped_variable_store_
->
Close
();
}
CaptureUsedVariableStoreWrapper
::
CaptureUsedVariableStoreWrapper
(
std
::
unique_ptr
<
VariableStore
>
variable_store
)
:
wrapped_variable_store_
(
std
::
move
(
variable_store
))
{}
tensorflow
::
Status
CaptureUsedVariableStoreWrapper
::
Lookup
(
const
string
&
name
,
VariableSpec
::
Format
format
,
std
::
vector
<
size_t
>
*
dimensions
,
AlignedArea
*
area
)
{
tensorflow
::
Status
status
=
wrapped_variable_store_
->
Lookup
(
name
,
format
,
dimensions
,
area
);
if
(
status
.
ok
())
{
// Capture the variable if the wrapped store's Lookup() succeeds.
VariableKey
key
(
name
,
format
);
std
::
pair
<
VariableKey
,
VariableValue
>
value
(
key
,
VariableValue
(
*
dimensions
,
*
area
));
if
(
index_
.
find
(
key
)
!=
index_
.
end
())
{
variables_
[
index_
[
key
]]
=
value
;
}
else
{
index_
[
key
]
=
variables_
.
size
();
variables_
.
push_back
(
value
);
}
}
return
status
;
}
tensorflow
::
Status
CaptureUsedVariableStoreWrapper
::
Close
()
{
return
wrapped_variable_store_
->
Close
();
}
FlexibleMatrixVariableStoreWrapper
::
FlexibleMatrixVariableStoreWrapper
(
std
::
unique_ptr
<
VariableStore
>
variable_store
)
:
wrapped_variable_store_
(
std
::
move
(
variable_store
))
{}
tensorflow
::
Status
FlexibleMatrixVariableStoreWrapper
::
Lookup
(
const
string
&
name
,
VariableSpec
::
Format
format
,
std
::
vector
<
size_t
>
*
dimensions
,
AlignedArea
*
area
)
{
// Forward requests that don't match the relevant suffix.
tensorflow
::
StringPiece
name_piece
=
name
;
if
(
!
tensorflow
::
str_util
::
ConsumeSuffix
(
&
name_piece
,
FlexibleMatrixKernel
::
kSuffix
))
{
return
wrapped_variable_store_
->
Lookup
(
name
,
format
,
dimensions
,
area
);
}
const
string
basename
=
name_piece
.
ToString
();
// Fetch the non-blocked, non-transposed version of the matrix. This wrapper
// will be nested inside the capturing wrapper, so we can do multiple lookups
// without capturing more variables than we need.
Matrix
<
float
>
plain_matrix
;
TF_RETURN_IF_ERROR
(
wrapped_variable_store_
->
Lookup
(
basename
,
&
plain_matrix
));
const
int
output_dimension
=
plain_matrix
.
num_columns
();
// Performance estimates for different methods. A mix of 32/48 blocked
// matrices got 28 GFLOPS, whereas only unblocked got 2.8 GFLOPS.
using
Candidate
=
std
::
tuple
<
float
,
VariableSpec
::
Format
,
string
>
;
const
std
::
vector
<
Candidate
>
candidates
=
{
Candidate
(
2.8
f
,
VariableSpec
::
FORMAT_ROW_MAJOR_MATRIX
,
tensorflow
::
strings
::
StrCat
(
basename
,
"/transposed"
)),
Candidate
(
EffectiveGflops
(
output_dimension
,
32
,
25.0
f
),
VariableSpec
::
FORMAT_COLUMN_BLOCKED_ROW_MAJOR_MATRIX
,
tensorflow
::
strings
::
StrCat
(
basename
,
"/matrix/blocked32"
)),
Candidate
(
EffectiveGflops
(
output_dimension
,
48
,
25.0
f
),
VariableSpec
::
FORMAT_COLUMN_BLOCKED_ROW_MAJOR_MATRIX
,
tensorflow
::
strings
::
StrCat
(
basename
,
"/matrix/blocked48"
))};
const
auto
max_it
=
std
::
max_element
(
candidates
.
begin
(),
candidates
.
end
());
const
VariableSpec
::
Format
argmax_format
=
std
::
get
<
1
>
(
*
max_it
);
const
string
&
argmax_name
=
std
::
get
<
2
>
(
*
max_it
);
// The requested |format| must match the best format. If not, return error
// and wait until the proper format is requested.
if
(
format
!=
argmax_format
)
{
return
tensorflow
::
errors
::
FailedPrecondition
(
"Sub-optimal matrix format: "
,
VariableSpec
::
Format_Name
(
format
),
" ("
,
VariableSpec
::
Format_Name
(
argmax_format
),
" is best)"
);
}
return
wrapped_variable_store_
->
Lookup
(
argmax_name
,
format
,
dimensions
,
area
);
}
tensorflow
::
Status
FlexibleMatrixVariableStoreWrapper
::
Close
()
{
return
wrapped_variable_store_
->
Close
();
}
}
// namespace runtime
}
// namespace dragnn
}
// namespace syntaxnet
research/syntaxnet/dragnn/runtime/variable_store_wrappers.h
deleted
100644 → 0
View file @
a4bb31d0
// Copyright 2017 Google Inc. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
// A set of VariableStore wrappers that provide compositional functionality.
// These are intended for offline processing and experimentation; avoid using
// these in production, where ArrayVariableStore and its subclasses should be
// used instead.
#ifndef DRAGNN_RUNTIME_VARIABLE_STORE_WRAPPERS_H_
#define DRAGNN_RUNTIME_VARIABLE_STORE_WRAPPERS_H_
#include <memory>
#include <string>
#include <utility>
#include <vector>
#include "dragnn/protos/runtime.pb.h"
#include "dragnn/runtime/alignment.h"
#include "dragnn/runtime/variable_store.h"
#include "syntaxnet/base.h"
#include "tensorflow/core/lib/core/status.h"
namespace
syntaxnet
{
namespace
dragnn
{
namespace
runtime
{
// A wrapper that looks for an averaged version of each variable in the wrapped
// store, and failing that optionally falls back to the non-averaged version.
class
TryAveragedVariableStoreWrapper
:
public
VariableStore
{
public:
// Wraps the |variable_store|. If |allow_fallback| is true, then when the
// averaged version is missing the non-averaged version can be substituted.
explicit
TryAveragedVariableStoreWrapper
(
std
::
unique_ptr
<
VariableStore
>
variable_store
,
bool
allow_fallback
=
false
);
// Implements VariableStore.
using
VariableStore
::
Lookup
;
// import Lookup<T>() convenience methods
tensorflow
::
Status
Lookup
(
const
string
&
name
,
VariableSpec
::
Format
format
,
std
::
vector
<
size_t
>
*
dimensions
,
AlignedArea
*
area
)
override
;
tensorflow
::
Status
Close
()
override
;
private:
// Wrapped variable store.
const
std
::
unique_ptr
<
VariableStore
>
wrapped_variable_store_
;
// Whether to allow fallback to the non-averaged variable.
const
bool
allow_fallback_
;
};
// A wrapper that captures each successfully retrieved variable. Useful for
// finding the exact set of variables used by some set of DRAGNN components.
class
CaptureUsedVariableStoreWrapper
:
public
VariableStore
{
public:
// `Variables` is a list of captured variables, in order that they are
// captured. We want to preserve the order, so that arrays are sequential in
// memory. `VariableKey` is name/format metadata used to uniquely identify
// a variable; duplicate lookups to the same variable will not capture it
// twice, and its position in the list will be the first position.
using
VariableKey
=
std
::
pair
<
string
,
VariableSpec
::
Format
>
;
using
VariableValue
=
std
::
pair
<
std
::
vector
<
size_t
>
,
AlignedArea
>
;
using
Variables
=
std
::
vector
<
std
::
pair
<
VariableKey
,
VariableValue
>>
;
// Wraps the |variable_store|.
explicit
CaptureUsedVariableStoreWrapper
(
std
::
unique_ptr
<
VariableStore
>
variable_store
);
// Implements VariableStore.
using
VariableStore
::
Lookup
;
// import Lookup<T>() convenience methods
tensorflow
::
Status
Lookup
(
const
string
&
name
,
VariableSpec
::
Format
format
,
std
::
vector
<
size_t
>
*
dimensions
,
AlignedArea
*
area
)
override
;
tensorflow
::
Status
Close
()
override
;
// Returns the current set of captured variables. The variable content in the
// returned mapping is valid while this lives.
const
Variables
&
variables
()
const
{
return
variables_
;
}
private:
// Wrapped variable store.
const
std
::
unique_ptr
<
VariableStore
>
wrapped_variable_store_
;
// Current set of captured variables.
Variables
variables_
;
// Indexes key --> position in variables_ list.
std
::
map
<
VariableKey
,
int
>
index_
;
};
// A wrapper that selects a matrix format for the FlexibleMatrixKernel. This
// could be done in the FlexibleMatrixKernel itself, but factoring it into this
// wrapper allows the selection to occur at model construction time instead of
// at model loading time.
class
FlexibleMatrixVariableStoreWrapper
:
public
VariableStore
{
public:
// Wraps the |variable_store|.
explicit
FlexibleMatrixVariableStoreWrapper
(
std
::
unique_ptr
<
VariableStore
>
variable_store
);
// Looks up the variable named |name| with format |format|, returning its
// shape in |dimensions| and its data in |area|. On error, returns non-OK.
//
// If the |name| does not end in FlexibleMatrixKernel::kSuffix, passes the
// request along to the |wrapped_variable_store_|. Otherwise, if |name| is
// "foo/<kSuffix>", estimates the throughput of the matrix "foo" in various
// formats (assuming the workload is matrix-vector multiplications), selects
// the fastest format, and returns the matrix in that format.
//
// It is an error if the selected matrix format does not match the requested
// variable |format| (e.g., non-blocked vs blocked). The FlexibleMatrixKernel
// should request the variable in all relevant variable formats, so eventually
// it will issue a request in a matching format.
tensorflow
::
Status
Lookup
(
const
string
&
name
,
VariableSpec
::
Format
format
,
std
::
vector
<
size_t
>
*
dimensions
,
AlignedArea
*
area
)
override
;
using
VariableStore
::
Lookup
;
// import Lookup<T>() convenience methods
// Implements VariableStore.
tensorflow
::
Status
Close
()
override
;
private:
// Wrapped variable store.
const
std
::
unique_ptr
<
VariableStore
>
wrapped_variable_store_
;
};
}
// namespace runtime
}
// namespace dragnn
}
// namespace syntaxnet
#endif // DRAGNN_RUNTIME_VARIABLE_STORE_WRAPPERS_H_
research/syntaxnet/dragnn/runtime/variable_store_wrappers_test.cc
deleted
100644 → 0
View file @
a4bb31d0
// Copyright 2017 Google Inc. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
#include "dragnn/runtime/variable_store_wrappers.h"
#include <stddef.h>
#include <map>
#include <memory>
#include <string>
#include <utility>
#include <vector>
#include "dragnn/core/test/generic.h"
#include "dragnn/protos/runtime.pb.h"
#include "dragnn/runtime/flexible_matrix_kernel.h"
#include "dragnn/runtime/math/transformations.h"
#include "dragnn/runtime/math/types.h"
#include "dragnn/runtime/test/fake_variable_store.h"
#include "syntaxnet/base.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/test.h"
namespace
syntaxnet
{
namespace
dragnn
{
namespace
runtime
{
namespace
{
// Returns a variable store with some default entries for tests. Specifically,
// "foo" has an averaged version while "bar" does not.
std
::
unique_ptr
<
VariableStore
>
NewVariableStore
()
{
std
::
unique_ptr
<
FakeVariableStore
>
store
(
new
FakeVariableStore
());
store
->
AddOrDie
(
"foo"
,
{{
1.0
,
2.0
},
//
{
3.0
,
4.0
}});
store
->
AddOrDie
(
"foo/ExponentialMovingAverage"
,
{{
10.0
,
20.0
},
//
{
30.0
,
40.0
}});
store
->
AddOrDie
(
"bar"
,
{{
10.0
,
9.0
,
8.0
},
//
{
7.0
,
6.0
,
5.0
}});
return
std
::
move
(
store
);
}
// Expects that the |vector| contains the |data|.
template
<
typename
T
>
void
ExpectVector
(
Vector
<
T
>
vector
,
const
std
::
vector
<
T
>
&
data
)
{
ASSERT_EQ
(
vector
.
size
(),
data
.
size
());
for
(
size_t
i
=
0
;
i
<
data
.
size
();
++
i
)
EXPECT_EQ
(
vector
[
i
],
data
[
i
]);
}
// Expects that the |matrix| contains the |data|.
void
ExpectMatrix
(
Matrix
<
float
>
matrix
,
const
std
::
vector
<
std
::
vector
<
float
>>
&
data
)
{
ASSERT_EQ
(
matrix
.
num_rows
(),
data
.
size
());
if
(
data
.
empty
())
return
;
ASSERT_EQ
(
matrix
.
num_columns
(),
data
[
0
].
size
());
for
(
size_t
i
=
0
;
i
<
data
.
size
();
++
i
)
ExpectVector
(
matrix
.
row
(
i
),
data
[
i
]);
}
// Tests that the averaging wrapper uses the averaged version of a variable if
// available, the non-averaged version failing that, and errors out otherwise.
TEST
(
TryAveragedVariableStoreWrapperTest
,
FallbackAllowed
)
{
TryAveragedVariableStoreWrapper
store
(
NewVariableStore
(),
/*allow_fallback=*/
true
);
Matrix
<
float
>
foo_averaged
;
Matrix
<
float
>
bar_non_averaged
;
Matrix
<
float
>
unused_matrix
;
TF_ASSERT_OK
(
store
.
Lookup
(
"foo"
,
&
foo_averaged
));
TF_ASSERT_OK
(
store
.
Lookup
(
"bar"
,
&
bar_non_averaged
));
EXPECT_THAT
(
store
.
Lookup
(
"missing"
,
&
unused_matrix
),
test
::
IsErrorWithSubstr
(
"Unknown variable"
));
TF_EXPECT_OK
(
store
.
Close
());
ExpectMatrix
(
foo_averaged
,
{{
10.0
,
20.0
},
//
{
30.0
,
40.0
}});
ExpectMatrix
(
bar_non_averaged
,
{{
10.0
,
9.0
,
8.0
},
//
{
7.0
,
6.0
,
5.0
}});
}
// As above, but with fallback disabled (the default behavior).
TEST
(
TryAveragedVariableStoreWrapperTest
,
FallbackForbidden
)
{
TryAveragedVariableStoreWrapper
store
(
NewVariableStore
());
Matrix
<
float
>
foo_averaged
;
Matrix
<
float
>
bar_non_averaged
;
Matrix
<
float
>
unused_matrix
;
TF_ASSERT_OK
(
store
.
Lookup
(
"foo"
,
&
foo_averaged
));
EXPECT_THAT
(
store
.
Lookup
(
"bar"
,
&
bar_non_averaged
),
test
::
IsErrorWithSubstr
(
"Failed to retrieve averaged variable "
"'bar/ExponentialMovingAverage' for "
"variable 'bar'"
));
EXPECT_THAT
(
store
.
Lookup
(
"missing"
,
&
unused_matrix
),
test
::
IsErrorWithSubstr
(
"Failed to retrieve averaged variable "
"'missing/ExponentialMovingAverage' for "
"variable 'missing'"
));
TF_EXPECT_OK
(
store
.
Close
());
ExpectMatrix
(
foo_averaged
,
{{
10.0
,
20.0
},
//
{
30.0
,
40.0
}});
}
// Tests that the capturing wrapper correctly records the set of variables that
// have been looked up.
TEST
(
CaptureUsedVariableStoreWrapperTest
,
Capturing
)
{
CaptureUsedVariableStoreWrapper
store
(
NewVariableStore
());
Vector
<
float
>
unused_vector
;
Matrix
<
float
>
unused_row_major_matrix
;
// Try a completely missing variable. As a failed lookup, this should not
// appear among the captured variables.
EXPECT_THAT
(
store
.
Lookup
(
"missing"
,
&
unused_vector
),
test
::
IsErrorWithSubstr
(
"Unknown variable"
));
// Look up one variable of each type.
TF_ASSERT_OK
(
store
.
Lookup
(
"foo"
,
&
unused_vector
));
TF_ASSERT_OK
(
store
.
Lookup
(
"bar"
,
&
unused_row_major_matrix
));
TF_EXPECT_OK
(
store
.
Close
());
// Check the names and formats of the captured variables.
const
auto
&
variables
=
store
.
variables
();
ASSERT_EQ
(
variables
.
size
(),
2
);
// The variables must be returned in order. Check their names and format
// first.
EXPECT_EQ
(
variables
[
0
].
first
.
first
,
"foo"
);
EXPECT_EQ
(
variables
[
0
].
first
.
second
,
VariableSpec
::
FORMAT_FLAT
);
EXPECT_EQ
(
variables
[
1
].
first
.
first
,
"bar"
);
EXPECT_EQ
(
variables
[
1
].
first
.
second
,
VariableSpec
::
FORMAT_ROW_MAJOR_MATRIX
);
// Check the content of 'foo'.
EXPECT_EQ
(
variables
[
0
].
second
.
first
,
std
::
vector
<
size_t
>
{
4
});
ExpectVector
(
Vector
<
float
>
(
variables
[
0
].
second
.
second
.
view
(
0
)),
{
1.0
,
2.0
,
3.0
,
4.0
});
// Check the content of 'bar'.
EXPECT_EQ
(
variables
[
1
].
second
.
first
,
std
::
vector
<
size_t
>
({
2
,
3
}));
ExpectMatrix
(
Matrix
<
float
>
(
variables
[
1
].
second
.
second
),
{{
10.0
,
9.0
,
8.0
},
//
{
7.0
,
6.0
,
5.0
}});
}
// Returns a variable store with some blocked and transposed matrices, for
// testing the flexible matrix wrapper.
std
::
unique_ptr
<
VariableStore
>
NewBlockedAndTransposedStore
()
{
std
::
unique_ptr
<
FakeVariableStore
>
store
(
new
FakeVariableStore
());
// A tiny matrix, which favors the non-blocked format.
store
->
AddOrDie
(
"1x1"
,
{{
1.0
}});
store
->
AddOrDie
(
"1x1/transposed"
,
{{
1.0
}});
store
->
AddOrDie
(
"1x1/matrix/blocked32"
,
{{
1.0
}});
store
->
AddOrDie
(
"1x1/matrix/blocked48"
,
{{
1.0
}});
// A matrix that is a multiple of 32, which should favor block size 32.
const
std
::
vector
<
float
>
row32
(
32
,
32.0
);
const
std
::
vector
<
std
::
vector
<
float
>>
data32
(
16
,
row32
);
store
->
AddOrDie
(
"16x32"
,
data32
);
store
->
AddOrDie
(
"16x32/transposed"
,
data32
);
store
->
AddOrDie
(
"16x32/matrix/blocked32"
,
data32
);
store
->
AddOrDie
(
"16x32/matrix/blocked48"
,
data32
);
// A matrix that is a multiple of 48, which should favor block size 48.
const
std
::
vector
<
float
>
row48
(
48
,
48.0
);
const
std
::
vector
<
std
::
vector
<
float
>>
data48
(
24
,
row48
);
store
->
AddOrDie
(
"24x48"
,
data48
);
store
->
AddOrDie
(
"24x48/transposed"
,
data48
);
store
->
AddOrDie
(
"24x48/matrix/blocked32"
,
data48
);
store
->
AddOrDie
(
"24x48/matrix/blocked48"
,
data48
);
return
std
::
move
(
store
);
}
// Expects that the |blocked_matrix| matches the |num_rows|, |num_columns|, and
// |block_size| and is filled with the |value|.
void
ExpectBlockedMatrix
(
BlockedMatrix
<
float
>
blocked_matrix
,
size_t
num_rows
,
size_t
num_columns
,
size_t
block_size
,
float
value
)
{
ASSERT_EQ
(
blocked_matrix
.
num_rows
(),
num_rows
);
ASSERT_EQ
(
blocked_matrix
.
num_columns
(),
num_columns
);
ASSERT_EQ
(
blocked_matrix
.
block_size
(),
block_size
);
const
std
::
vector
<
float
>
expected_vector
(
block_size
,
value
);
for
(
size_t
i
=
0
;
i
<
blocked_matrix
.
num_vectors
();
++
i
)
{
ExpectVector
(
blocked_matrix
.
vector
(
i
),
expected_vector
);
}
}
// Tests that the flexible matrix wrapper passes through variables that don't
// end in the right suffix.
TEST
(
FlexibleMatrixVariableStoreWrapperTest
,
PassThroughIrrelevantVariables
)
{
FlexibleMatrixVariableStoreWrapper
store
(
NewBlockedAndTransposedStore
());
Vector
<
float
>
vector
;
EXPECT_THAT
(
store
.
Lookup
(
"missing"
,
&
vector
),
test
::
IsErrorWithSubstr
(
"Unknown variable"
));
TF_ASSERT_OK
(
store
.
Lookup
(
"1x1"
,
&
vector
));
ExpectVector
(
vector
,
{
1.0
});
TF_EXPECT_OK
(
store
.
Close
());
}
// Tests that the flexible matrix wrapper selects the plain matrix format for
// tiny matrices.
TEST
(
FlexibleMatrixVariableStoreWrapperTest
,
SelectPlainMatrixFormat
)
{
FlexibleMatrixVariableStoreWrapper
store
(
NewBlockedAndTransposedStore
());
Matrix
<
float
>
plain_matrix
;
BlockedMatrix
<
float
>
blocked_matrix
;
const
string
name
=
tensorflow
::
strings
::
StrCat
(
"1x1"
,
FlexibleMatrixKernel
::
kSuffix
);
EXPECT_THAT
(
store
.
Lookup
(
name
,
&
blocked_matrix
),
test
::
IsErrorWithSubstr
(
"Sub-optimal matrix format"
));
TF_ASSERT_OK
(
store
.
Lookup
(
name
,
&
plain_matrix
));
ExpectMatrix
(
plain_matrix
,
{{
1.0
}});
TF_EXPECT_OK
(
store
.
Close
());
}
// Tests that the flexible matrix wrapper selects block size 32 for a matrix
// whose size is a multiple of 32.
TEST
(
FlexibleMatrixVariableStoreWrapperTest
,
SelectBlocked32MatrixFormat
)
{
FlexibleMatrixVariableStoreWrapper
store
(
NewBlockedAndTransposedStore
());
Matrix
<
float
>
plain_matrix
;
BlockedMatrix
<
float
>
blocked_matrix
;
const
string
name
=
tensorflow
::
strings
::
StrCat
(
"16x32"
,
FlexibleMatrixKernel
::
kSuffix
);
EXPECT_THAT
(
store
.
Lookup
(
name
,
&
plain_matrix
),
test
::
IsErrorWithSubstr
(
"Sub-optimal matrix format"
));
TF_ASSERT_OK
(
store
.
Lookup
(
name
,
&
blocked_matrix
));
ExpectBlockedMatrix
(
blocked_matrix
,
16
,
32
,
32
,
32.0
);
TF_EXPECT_OK
(
store
.
Close
());
}
// Tests that the flexible matrix wrapper selects block size 48 for a matrix
// whose size is a multiple of 48.
TEST
(
FlexibleMatrixVariableStoreWrapperTest
,
SelectBlocked48MatrixFormat
)
{
FlexibleMatrixVariableStoreWrapper
store
(
NewBlockedAndTransposedStore
());
Matrix
<
float
>
plain_matrix
;
BlockedMatrix
<
float
>
blocked_matrix
;
const
string
name
=
tensorflow
::
strings
::
StrCat
(
"24x48"
,
FlexibleMatrixKernel
::
kSuffix
);
EXPECT_THAT
(
store
.
Lookup
(
name
,
&
plain_matrix
),
test
::
IsErrorWithSubstr
(
"Sub-optimal matrix format"
));
TF_ASSERT_OK
(
store
.
Lookup
(
name
,
&
blocked_matrix
));
ExpectBlockedMatrix
(
blocked_matrix
,
24
,
48
,
48
,
48.0
);
TF_EXPECT_OK
(
store
.
Close
());
}
}
// namespace
}
// namespace runtime
}
// namespace dragnn
}
// namespace syntaxnet
research/syntaxnet/dragnn/runtime/xla/BUILD
deleted
100644 → 0
View file @
a4bb31d0
package
(
default_visibility
=
[
"//visibility:public"
])
# TODO(googleuser): Move XLA libs to dragnn/runtime when stable. Probably there
# should be a refactor with the Myelin libs since they are so similar.
load
(
"//dragnn/runtime/xla:xla_build_defs.bzl"
,
"dragnn_xla_aot_components"
,
)
load
(
"//dragnn/runtime:multiarch.bzl"
,
"dragnn_cc_multiarch_library"
,
"dragnn_cc_multiarch_test"
,
)
filegroup
(
name
=
"test_xla_compilation_output"
,
srcs
=
glob
([
"testdata/xla_compilation_output/**"
]),
)
cc_binary
(
name
=
"xla_extract_config"
,
srcs
=
[
"xla_extract_config.cc"
],
deps
=
[
":xla_graph_utils"
,
"//dragnn/protos:export_proto_cc"
,
"@org_tensorflow//tensorflow/compiler/tf2xla:tf2xla_proto"
,
"@org_tensorflow//tensorflow/core:framework"
,
"@org_tensorflow//tensorflow/core:lib"
,
"@org_tensorflow//tensorflow/core:protos_all_cc"
,
],
)
cc_binary
(
name
=
"xla_extract_names_from_specs"
,
srcs
=
[
"xla_extract_names_from_specs.cc"
],
deps
=
[
":xla_spec_build_utils"
,
"//syntaxnet:base"
,
"@org_tensorflow//tensorflow/core:lib"
,
],
)
cc_library
(
name
=
"xla_cell_converter"
,
srcs
=
[
"xla_cell_converter.cc"
],
hdrs
=
[
"xla_cell_converter.h"
],
deps
=
[
":xla_graph_utils"
,
"//dragnn/protos:export_proto_cc"
,
"//dragnn/runtime:trained_model"
,
"//syntaxnet:base"
,
"@org_tensorflow//tensorflow/core:framework_headers_lib"
,
"@org_tensorflow//tensorflow/core:lib"
,
"@org_tensorflow//tensorflow/core:protos_all_cc"
,
],
)
cc_test
(
name
=
"xla_cell_converter_test"
,
size
=
"small"
,
timeout
=
"moderate"
,
srcs
=
[
"xla_cell_converter_test.cc"
],
data
=
[
"//dragnn/runtime:test_rnn_tagger"
],
deps
=
[
":xla_cell_converter"
,
":xla_graph_utils"
,
":xla_spec_utils"
,
"//dragnn/components/syntaxnet:syntaxnet_component"
,
"//dragnn/core/test:generic"
,
"//dragnn/protos:export_proto_cc"
,
"//dragnn/runtime:alignment"
,
"//dragnn/runtime:trained_model"
,
"//syntaxnet:base"
,
"@org_tensorflow//tensorflow/compiler/tf2xla:tf2xla_proto"
,
"@org_tensorflow//tensorflow/compiler/tf2xla:xla_jit_compiled_cpu_function"
,
"@org_tensorflow//tensorflow/compiler/xla:shape_util"
,
"@org_tensorflow//tensorflow/core:lib"
,
"@org_tensorflow//tensorflow/core:protos_all_cc"
,
"@org_tensorflow//tensorflow/core:test"
,
],
)
cc_library
(
name
=
"xla_compilation"
,
srcs
=
[
"xla_compilation.cc"
],
hdrs
=
[
"xla_compilation.h"
],
deps
=
[
":xla_cell_converter"
,
":xla_graph_utils"
,
":xla_spec_utils"
,
"//dragnn/protos:export_proto_cc"
,
"//dragnn/protos:spec_proto_cc"
,
"//dragnn/runtime:component"
,
"//dragnn/runtime:trained_model"
,
"//syntaxnet:base"
,
"@org_tensorflow//tensorflow/core:lib"
,
"@org_tensorflow//tensorflow/core:protos_all_cc"
,
],
)
cc_test
(
name
=
"xla_compilation_test"
,
size
=
"small"
,
timeout
=
"moderate"
,
srcs
=
[
"xla_compilation_test.cc"
],
data
=
[
":test_xla_compilation_output"
,
"//dragnn/runtime:test_rnn_tagger"
,
],
deps
=
[
":xla_compilation"
,
":xla_spec_utils"
,
"//dragnn/components/syntaxnet:syntaxnet_component"
,
"//dragnn/core/test:generic"
,
"//dragnn/protos:spec_proto_cc"
,
"//syntaxnet:base"
,
"@org_tensorflow//tensorflow/core:lib"
,
"@org_tensorflow//tensorflow/core:test"
,
],
)
dragnn_cc_multiarch_library
(
name
=
"xla_dynamic_component_base"
,
srcs
=
[
"xla_dynamic_component_base.cc"
],
hdrs
=
[
"xla_dynamic_component_base.h"
],
deps
=
[
":xla_spec_utils"
,
"//dragnn/protos:export_proto_cc"
,
"//dragnn/protos:spec_proto_cc"
,
"//dragnn/protos:trace_proto_cc"
,
"//dragnn/runtime:alignment"
,
"//dragnn/runtime:component"
,
"//dragnn/runtime:extensions"
,
"//dragnn/runtime:fixed_embeddings"
,
"//dragnn/runtime:linked_embeddings"
,
"//dragnn/runtime:network_states"
,
"//dragnn/runtime:session_state"
,
"//dragnn/runtime:transition_system_traits"
,
"//dragnn/runtime:type_keyed_set"
,
"//dragnn/runtime:variable_store"
,
"//dragnn/runtime/math:types"
,
"//syntaxnet:base"
,
"@org_tensorflow//tensorflow/compiler/tf2xla:xla_compiled_cpu_function"
,
"@org_tensorflow//tensorflow/compiler/xla:shape_util"
,
"@org_tensorflow//tensorflow/compiler/xla:xla_data_proto"
,
"@org_tensorflow//tensorflow/core:framework_headers_lib"
,
"@org_tensorflow//tensorflow/core:lib"
,
"@org_tensorflow//tensorflow/core:protos_all_cc"
,
],
)
dragnn_cc_multiarch_library
(
name
=
"sequence_xla_dynamic_component_mixin"
,
hdrs
=
[
"sequence_xla_dynamic_component_mixin.h"
],
deps
=
[
":xla_dynamic_component_base"
,
"//dragnn/core:compute_session"
,
"//dragnn/protos:spec_proto_cc"
,
"//dragnn/protos:trace_proto_cc"
,
"//dragnn/runtime:extensions"
,
"//dragnn/runtime:network_states"
,
"//dragnn/runtime:sequence_features"
,
"//dragnn/runtime:sequence_links"
,
"//dragnn/runtime:sequence_model"
,
"//dragnn/runtime:session_state"
,
"//dragnn/runtime:variable_store"
,
"//dragnn/runtime/math:types"
,
"//syntaxnet:base"
,
"@org_tensorflow//tensorflow/compiler/tf2xla:xla_compiled_cpu_function"
,
"@org_tensorflow//tensorflow/core:lib"
,
],
)
dragnn_cc_multiarch_test
(
name
=
"sequence_xla_dynamic_component_mixin_test"
,
size
=
"small"
,
srcs
=
[
"sequence_xla_dynamic_component_mixin_test.cc"
],
deps
=
[
":xla_dynamic_component"
,
":xla_graph_utils"
,
":xla_spec_utils"
,
"//dragnn/core:input_batch_cache"
,
"//dragnn/core/test:generic"
,
"//dragnn/protos:cell_trace_proto_cc"
,
"//dragnn/protos:export_proto_cc"
,
"//dragnn/protos:spec_proto_cc"
,
"//dragnn/protos:trace_proto_cc"
,
"//dragnn/runtime:component"
,
"//dragnn/runtime:extensions"
,
"//dragnn/runtime:network_states"
,
"//dragnn/runtime:sequence_backend"
,
"//dragnn/runtime:sequence_extractor"
,
"//dragnn/runtime:sequence_predictor"
,
"//dragnn/runtime/math:types"
,
"//dragnn/runtime/test:network_test_base"
,
"//syntaxnet:base"
,
"@org_tensorflow//tensorflow/compiler/xla:xla_data_proto"
,
"@org_tensorflow//tensorflow/core:framework_headers_lib"
,
"@org_tensorflow//tensorflow/core:lib"
,
"@org_tensorflow//tensorflow/core:protos_all_cc"
,
"@org_tensorflow//tensorflow/core:test"
,
],
)
dragnn_cc_multiarch_library
(
name
=
"xla_aot_dynamic_component"
,
hdrs
=
[
"xla_aot_dynamic_component.h"
],
deps
=
[
":sequence_xla_dynamic_component_mixin"
,
":xla_dynamic_component_base"
,
":xla_graph_utils"
,
":xla_spec_utils"
,
"//dragnn/protos:export_proto_cc"
,
"//dragnn/protos:spec_proto_cc"
,
"//dragnn/runtime:component"
,
"//syntaxnet:base"
,
"@org_tensorflow//tensorflow/compiler/tf2xla:tf2xla_proto"
,
"@org_tensorflow//tensorflow/compiler/tf2xla:xla_compiled_cpu_function"
,
"@org_tensorflow//tensorflow/core:lib"
,
"@org_tensorflow//tensorflow/core:protos_all_cc"
,
],
)
dragnn_cc_multiarch_library
(
name
=
"xla_dynamic_component"
,
srcs
=
[
"xla_dynamic_component.cc"
],
deps
=
[
":sequence_xla_dynamic_component_mixin"
,
":xla_dynamic_component_base"
,
":xla_graph_utils"
,
":xla_spec_utils"
,
"//dragnn/core:compute_session"
,
"//dragnn/protos:export_proto_cc"
,
"//dragnn/protos:spec_proto_cc"
,
"//dragnn/protos:trace_proto_cc"
,
"//dragnn/runtime:component"
,
"//dragnn/runtime:fixed_embeddings"
,
"//dragnn/runtime:linked_embeddings"
,
"//dragnn/runtime:network_states"
,
"//dragnn/runtime:session_state"
,
"//dragnn/runtime/math:types"
,
"//syntaxnet:base"
,
"@org_tensorflow//tensorflow/compiler/tf2xla:tf2xla_proto"
,
"@org_tensorflow//tensorflow/compiler/tf2xla:xla_compiled_cpu_function"
,
"@org_tensorflow//tensorflow/compiler/tf2xla:xla_jit_compiled_cpu_function"
,
"@org_tensorflow//tensorflow/core:lib"
,
"@org_tensorflow//tensorflow/core:protos_all_cc"
,
],
alwayslink
=
1
,
)
dragnn_cc_multiarch_test
(
name
=
"xla_dynamic_component_test"
,
size
=
"small"
,
srcs
=
[
"xla_dynamic_component_test.cc"
],
deps
=
[
":xla_dynamic_component"
,
":xla_graph_utils"
,
":xla_spec_utils"
,
"//dragnn/core/test:generic"
,
"//dragnn/protos:cell_trace_proto_cc"
,
"//dragnn/protos:export_proto_cc"
,
"//dragnn/protos:spec_proto_cc"
,
"//dragnn/protos:trace_proto_cc"
,
"//dragnn/runtime:component"
,
"//dragnn/runtime:extensions"
,
"//dragnn/runtime:network_states"
,
"//dragnn/runtime:session_state"
,
"//dragnn/runtime:type_keyed_set"
,
"//dragnn/runtime/math:types"
,
"//dragnn/runtime/test:fake_variable_store"
,
"//dragnn/runtime/test:network_test_base"
,
"//syntaxnet:base"
,
"@org_tensorflow//tensorflow/compiler/xla:xla_data_proto"
,
"@org_tensorflow//tensorflow/core:framework_headers_lib"
,
"@org_tensorflow//tensorflow/core:lib"
,
"@org_tensorflow//tensorflow/core:protos_all_cc"
,
"@org_tensorflow//tensorflow/core:test"
,
],
)
cc_library
(
name
=
"xla_graph_utils"
,
srcs
=
[
"xla_graph_utils.cc"
],
hdrs
=
[
"xla_graph_utils.h"
],
deps
=
[
":xla_spec_utils"
,
"//dragnn/protos:export_proto_cc"
,
"//syntaxnet:base"
,
"@org_tensorflow//tensorflow/compiler/tf2xla:tf2xla_proto"
,
"@org_tensorflow//tensorflow/core:framework_headers_lib"
,
"@org_tensorflow//tensorflow/core:lib"
,
"@org_tensorflow//tensorflow/core:protos_all_cc"
,
],
)
cc_test
(
name
=
"xla_graph_utils_test"
,
srcs
=
[
"xla_graph_utils_test.cc"
],
deps
=
[
":xla_graph_utils"
,
"//dragnn/core/test:generic"
,
"//dragnn/protos:export_proto_cc"
,
"//syntaxnet:base"
,
"@org_tensorflow//tensorflow/compiler/tf2xla:tf2xla_proto"
,
"@org_tensorflow//tensorflow/core:framework_headers_lib"
,
"@org_tensorflow//tensorflow/core:lib"
,
"@org_tensorflow//tensorflow/core:protos_all_cc"
,
"@org_tensorflow//tensorflow/core:test"
,
],
)
cc_library
(
name
=
"xla_spec_build_utils"
,
srcs
=
[
"xla_spec_build_utils.cc"
],
hdrs
=
[
"xla_spec_build_utils.h"
],
deps
=
[
":xla_spec_utils"
,
"//dragnn/protos:spec_proto_cc"
,
"//syntaxnet:base"
,
"@org_tensorflow//tensorflow/core:lib"
,
],
)
cc_test
(
name
=
"xla_spec_build_utils_test"
,
srcs
=
[
"xla_spec_build_utils_test.cc"
],
deps
=
[
":xla_spec_build_utils"
,
"//dragnn/core/test:generic"
,
"//dragnn/protos:export_proto_cc"
,
"//dragnn/protos:spec_proto_cc"
,
"@org_tensorflow//tensorflow/core:lib"
,
"@org_tensorflow//tensorflow/core:test"
,
],
)
cc_library
(
name
=
"xla_spec_utils"
,
srcs
=
[
"xla_spec_utils.cc"
],
hdrs
=
[
"xla_spec_utils.h"
],
deps
=
[
"//dragnn/protos:export_proto_cc"
,
"//dragnn/protos:spec_proto_cc"
,
"//syntaxnet:base"
,
"@org_tensorflow//tensorflow/core:lib"
,
],
)
cc_test
(
name
=
"xla_spec_utils_test"
,
srcs
=
[
"xla_spec_utils_test.cc"
],
deps
=
[
":xla_spec_utils"
,
"//dragnn/core/test:generic"
,
"//dragnn/protos:spec_proto_cc"
,
"//syntaxnet:base"
,
"@org_tensorflow//tensorflow/core:lib"
,
"@org_tensorflow//tensorflow/core:test"
,
],
)
research/syntaxnet/dragnn/runtime/xla/sequence_xla_dynamic_component_mixin.h
deleted
100644 → 0
View file @
a4bb31d0
// Copyright 2017 Google Inc. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
#ifndef DRAGNN_RUNTIME_XLA_SEQUENCE_XLA_DYNAMIC_COMPONENT_MIXIN_H_
#define DRAGNN_RUNTIME_XLA_SEQUENCE_XLA_DYNAMIC_COMPONENT_MIXIN_H_
#include <stddef.h>
#include <string>
#include <type_traits>
#include <vector>
#include "dragnn/core/compute_session.h"
#include "dragnn/protos/spec.pb.h"
#include "dragnn/protos/trace.pb.h"
#include "dragnn/runtime/extensions.h"
#include "dragnn/runtime/math/types.h"
#include "dragnn/runtime/network_states.h"
#include "dragnn/runtime/sequence_features.h"
#include "dragnn/runtime/sequence_links.h"
#include "dragnn/runtime/sequence_model.h"
#include "dragnn/runtime/session_state.h"
#include "dragnn/runtime/variable_store.h"
#include "dragnn/runtime/xla/xla_dynamic_component_base.h"
#include "syntaxnet/base.h"
#include "tensorflow/compiler/tf2xla/xla_compiled_cpu_function.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/core/stringpiece.h"
#include "tensorflow/core/lib/strings/str_util.h"
namespace
syntaxnet
{
namespace
dragnn
{
namespace
runtime
{
// A mixin that converts an XlaDynamicComponent variant into a sequence-based
// version. The |Base| must be a subclass of XlaDynamicComponentBase.
template
<
class
Base
>
class
SequenceXlaDynamicComponentMixin
:
public
Base
{
public:
static_assert
(
std
::
is_base_of
<
XlaDynamicComponentBase
,
Base
>::
value
,
"SequenceXlaDynamicComponentMixin must template on a subclass "
"of XlaDynamicComponentBase"
);
// Implements Component.
bool
Supports
(
const
ComponentSpec
&
component_spec
,
const
string
&
normalized_builder_name
)
const
override
;
tensorflow
::
Status
Initialize
(
const
ComponentSpec
&
component_spec
,
VariableStore
*
variable_store
,
NetworkStateManager
*
network_state_manager
,
ExtensionManager
*
extension_manager
)
override
;
tensorflow
::
Status
Evaluate
(
SessionState
*
session_state
,
ComputeSession
*
compute_session
,
ComponentTrace
*
component_trace
)
const
override
;
private:
// Binds the fixed feature IDs for the |target_index|'th element of the
// |features| to the |instance|. Uses locals in the |network_states|.
void
BindInputIds
(
const
SequenceFeatures
&
features
,
int
target_index
,
const
NetworkStates
&
network_states
,
tensorflow
::
XlaCompiledCpuFunction
*
instance
)
const
;
// Binds the linked embeddings for the |target_index|'th element in the
// |links| to the |instance|.
void
BindInputLinks
(
const
SequenceLinks
&
links
,
int
target_index
,
tensorflow
::
XlaCompiledCpuFunction
*
instance
)
const
;
// Sequence-based model evaluator.
SequenceModel
sequence_model_
;
// Intermediate values used by sequence models.
SharedExtensionHandle
<
SequenceModel
::
EvaluateState
>
evaluate_state_handle_
;
};
template
<
class
Base
>
bool
SequenceXlaDynamicComponentMixin
<
Base
>::
Supports
(
const
ComponentSpec
&
component_spec
,
const
string
&
normalized_builder_name
)
const
{
tensorflow
::
StringPiece
name
=
normalized_builder_name
;
return
tensorflow
::
str_util
::
ConsumePrefix
(
&
name
,
"Sequence"
)
&&
Base
::
Supports
(
component_spec
,
name
.
ToString
())
&&
SequenceModel
::
Supports
(
component_spec
);
}
template
<
class
Base
>
tensorflow
::
Status
SequenceXlaDynamicComponentMixin
<
Base
>::
Initialize
(
const
ComponentSpec
&
component_spec
,
VariableStore
*
variable_store
,
NetworkStateManager
*
network_state_manager
,
ExtensionManager
*
extension_manager
)
{
// Initialize the base class first, so its FixedEmbeddingManager and
// LinkedEmbeddingManager can be wrapped in sequence-based versions.
TF_RETURN_IF_ERROR
(
Base
::
Initialize
(
component_spec
,
variable_store
,
network_state_manager
,
extension_manager
));
TF_RETURN_IF_ERROR
(
sequence_model_
.
Initialize
(
component_spec
,
Base
::
kLogitsName
,
&
Base
::
fixed_embedding_manager
(),
&
Base
::
linked_embedding_manager
(),
network_state_manager
));
extension_manager
->
GetShared
(
&
evaluate_state_handle_
);
return
tensorflow
::
Status
::
OK
();
}
template
<
class
Base
>
void
SequenceXlaDynamicComponentMixin
<
Base
>::
BindInputIds
(
const
SequenceFeatures
&
features
,
int
target_index
,
const
NetworkStates
&
network_states
,
tensorflow
::
XlaCompiledCpuFunction
*
instance
)
const
{
for
(
size_t
channel_id
=
0
;
channel_id
<
features
.
num_channels
();
++
channel_id
)
{
const
MutableVector
<
int32
>
id_vector
=
network_states
.
GetLocal
(
Base
::
fixed_embedding_manager
().
id_handle
(
channel_id
,
0
));
id_vector
[
0
]
=
features
.
GetId
(
channel_id
,
target_index
);
Base
::
BindInput
(
Vector
<
int32
>
(
id_vector
),
Base
::
input_ids
()[
channel_id
].
id
,
instance
);
}
}
template
<
class
Base
>
void
SequenceXlaDynamicComponentMixin
<
Base
>::
BindInputLinks
(
const
SequenceLinks
&
links
,
int
target_index
,
tensorflow
::
XlaCompiledCpuFunction
*
instance
)
const
{
Vector
<
float
>
embedding
;
bool
is_out_of_bounds
=
false
;
for
(
size_t
channel_id
=
0
;
channel_id
<
links
.
num_channels
();
++
channel_id
)
{
links
.
Get
(
channel_id
,
target_index
,
&
embedding
,
&
is_out_of_bounds
);
Base
::
BindInputLink
(
embedding
,
is_out_of_bounds
,
Base
::
input_links
()[
channel_id
],
instance
);
}
}
template
<
class
Base
>
tensorflow
::
Status
SequenceXlaDynamicComponentMixin
<
Base
>::
Evaluate
(
SessionState
*
session_state
,
ComputeSession
*
compute_session
,
ComponentTrace
*
component_trace
)
const
{
NetworkStates
&
network_states
=
session_state
->
network_states
;
SequenceModel
::
EvaluateState
&
state
=
session_state
->
extensions
.
Get
(
evaluate_state_handle_
);
TF_RETURN_IF_ERROR
(
sequence_model_
.
Preprocess
(
session_state
,
compute_session
,
&
state
));
// Avoid ComputeSession overhead by directly iterating over the feature IDs.
// Handle forward and reverse iteration via an index and increment.
int
target_index
=
sequence_model_
.
left_to_right
()
?
0
:
state
.
num_steps
-
1
;
const
int
target_increment
=
sequence_model_
.
left_to_right
()
?
1
:
-
1
;
tensorflow
::
XlaCompiledCpuFunction
&
instance
=
Base
::
GetInstance
(
session_state
);
for
(
size_t
step_index
=
0
;
step_index
<
state
.
num_steps
;
++
step_index
,
target_index
+=
target_increment
)
{
// Bind inputs and outputs into the |instance|.
BindInputIds
(
state
.
features
,
target_index
,
network_states
,
&
instance
);
BindInputLinks
(
state
.
links
,
target_index
,
&
instance
);
Base
::
BindInputRecurrences
(
step_index
,
network_states
,
&
instance
);
// Invoke the cell in the |instance|.
if
(
!
instance
.
Run
())
{
return
tensorflow
::
errors
::
Internal
(
"Error executing cell for "
,
Base
::
name
(),
": "
,
instance
.
error_msg
());
}
// Realizes the binding: copy outputs out of the |instance|.
Base
::
BindOutputLayers
(
step_index
,
network_states
,
&
instance
);
Base
::
MaybeTrace
(
step_index
,
&
instance
,
component_trace
);
}
return
sequence_model_
.
Predict
(
network_states
,
&
state
);
}
}
// namespace runtime
}
// namespace dragnn
}
// namespace syntaxnet
#endif // DRAGNN_RUNTIME_XLA_SEQUENCE_XLA_DYNAMIC_COMPONENT_MIXIN_H_
research/syntaxnet/dragnn/runtime/xla/sequence_xla_dynamic_component_mixin_test.cc
deleted
100644 → 0
View file @
a4bb31d0
// Copyright 2017 Google Inc. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
#include <memory>
#include <string>
#include "dragnn/core/input_batch_cache.h"
#include "dragnn/core/test/generic.h"
#include "dragnn/protos/cell_trace.pb.h"
#include "dragnn/protos/export.pb.h"
#include "dragnn/protos/spec.pb.h"
#include "dragnn/protos/trace.pb.h"
#include "dragnn/runtime/component.h"
#include "dragnn/runtime/extensions.h"
#include "dragnn/runtime/math/types.h"
#include "dragnn/runtime/network_states.h"
#include "dragnn/runtime/sequence_backend.h"
#include "dragnn/runtime/sequence_extractor.h"
#include "dragnn/runtime/sequence_predictor.h"
#include "dragnn/runtime/test/network_test_base.h"
#include "dragnn/runtime/xla/xla_graph_utils.h"
#include "dragnn/runtime/xla/xla_spec_utils.h"
#include "syntaxnet/base.h"
#include <gmock/gmock.h>
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/framework/graph.pb.h"
#include "tensorflow/core/framework/node_def_builder.h"
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/lib/io/path.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/test.h"
namespace
syntaxnet
{
namespace
dragnn
{
namespace
runtime
{
namespace
{
using
::
testing
::
Return
;
constexpr
int
kVocabularySize
=
123
;
constexpr
int
kLogitsDim
=
11
;
constexpr
int
kNumSteps
=
50
;
// Sequence extractor that extracts [0, 2, 4, ...].
class
EvenNumbers
:
public
SequenceExtractor
{
public:
// Implements SequenceExtractor.
bool
Supports
(
const
FixedFeatureChannel
&
,
const
ComponentSpec
&
)
const
override
{
return
true
;
}
tensorflow
::
Status
Initialize
(
const
FixedFeatureChannel
&
,
const
ComponentSpec
&
)
override
{
return
tensorflow
::
Status
::
OK
();
}
tensorflow
::
Status
GetIds
(
InputBatchCache
*
,
std
::
vector
<
int32
>
*
ids
)
const
override
{
ids
->
clear
();
for
(
int
i
=
0
;
i
<
kNumSteps
;
++
i
)
ids
->
push_back
(
2
*
i
);
return
tensorflow
::
Status
::
OK
();
}
};
DRAGNN_RUNTIME_REGISTER_SEQUENCE_EXTRACTOR
(
EvenNumbers
);
// Trivial predictor that does nothing.
class
NoPredictions
:
public
SequencePredictor
{
public:
// Implements SequenceLinker.
bool
Supports
(
const
ComponentSpec
&
)
const
override
{
return
true
;
}
tensorflow
::
Status
Initialize
(
const
ComponentSpec
&
)
override
{
return
tensorflow
::
Status
::
OK
();
}
tensorflow
::
Status
Predict
(
Matrix
<
float
>
,
InputBatchCache
*
)
const
override
{
return
tensorflow
::
Status
::
OK
();
}
};
DRAGNN_RUNTIME_REGISTER_SEQUENCE_PREDICTOR
(
NoPredictions
);
class
SequenceXlaDynamicComponentMixinTest
:
public
NetworkTestBase
{
protected:
SequenceXlaDynamicComponentMixinTest
()
{
EXPECT_CALL
(
compute_session_
,
GetInputBatchCache
())
.
WillRepeatedly
(
Return
(
&
input_
));
EXPECT_CALL
(
compute_session_
,
GetReadiedComponent
(
kTestComponentName
))
.
WillRepeatedly
(
Return
(
&
backend_
));
}
// Options for building a GraphDef file for tests. By default, this specifies
// a working GraphDef file, but settings can be perturbed to trigger errors.
struct
GraphDefOptions
{
GraphDefOptions
()
=
default
;
// Dimension of the classification logits.
int
logits_dim
=
kLogitsDim
;
// Name of the variable containing the classification logits.
string
logits_name
=
"logits"
;
// Type of the feature ID input.
xla
::
PrimitiveType
id_type
=
xla
::
S32
;
// Dimension of the feature ID input.
int
id_dim
=
1
;
};
// Builds and writes a simple frozen GraphDef file. By default it produces a
// valid frozen GraphDef, but arguments can be overridden for error testing.
// Returns the path to the file.
static
string
WriteFrozenGraphDef
()
{
return
WriteFrozenGraphDef
(
GraphDefOptions
());
}
static
tensorflow
::
DataType
TensorFlowType
(
xla
::
PrimitiveType
type
)
{
switch
(
type
)
{
case
xla
::
S32
:
return
tensorflow
::
DT_INT32
;
case
xla
::
S64
:
return
tensorflow
::
DT_INT64
;
case
xla
::
F32
:
return
tensorflow
::
DT_FLOAT
;
default:
break
;
}
return
tensorflow
::
DT_INVALID
;
}
static
string
WriteFrozenGraphDef
(
const
GraphDefOptions
&
options
)
{
CellSubgraphSpec
spec
;
tensorflow
::
GraphDef
graph
;
// A fixed feature ID input.
auto
*
input
=
spec
.
add_input
();
input
->
set_name
(
"fixed_channel_0_index_0_ids"
);
input
->
set_tensor
(
"cell/id:0"
);
input
->
set_type
(
CellSubgraphSpec
::
Input
::
TYPE_FEATURE
);
// The retrieved embedding row, as logits.
auto
*
output
=
spec
.
add_output
();
output
->
set_name
(
options
.
logits_name
);
output
->
set_tensor
(
"cell/lookup:0"
);
// Add CellSubgraphSpec node.
tensorflow
::
Tensor
spec_tensor
(
tensorflow
::
DT_STRING
,
tensorflow
::
TensorShape
({
1
}));
spec
.
SerializeToString
(
&
spec_tensor
.
vec
<
string
>
()(
0
));
tensorflow
::
TensorProto
spec_tensor_proto
;
spec_tensor
.
AsProtoField
(
&
spec_tensor_proto
);
TF_CHECK_OK
(
tensorflow
::
NodeDefBuilder
(
kFrozenCellSubgraphSpecNodeName
,
"Const"
)
.
Attr
(
"dtype"
,
tensorflow
::
DT_STRING
)
.
Attr
(
"value"
,
spec_tensor_proto
)
.
Attr
(
"shape"
,
tensorflow
::
TensorShape
({
1
}))
.
Finalize
(
graph
.
add_node
()));
// Fixed feature ID input placeholder node.
TF_CHECK_OK
(
tensorflow
::
NodeDefBuilder
(
"cell/id"
,
"Placeholder"
)
.
Attr
(
"dtype"
,
TensorFlowType
(
options
.
id_type
))
.
Attr
(
"shape"
,
tensorflow
::
TensorShape
({
options
.
id_dim
}))
.
Finalize
(
graph
.
add_node
()));
// An embedding matrix constant. Each embedding is filled with its index.
tensorflow
::
Tensor
embeddings
(
tensorflow
::
DT_FLOAT
,
tensorflow
::
TensorShape
({
kVocabularySize
,
options
.
logits_dim
}));
auto
raw_tensor
=
embeddings
.
tensor
<
float
,
2
>
();
for
(
int
row
=
0
;
row
<
kVocabularySize
;
++
row
)
{
for
(
int
column
=
0
;
column
<
options
.
logits_dim
;
++
column
)
{
raw_tensor
(
row
,
column
)
=
row
;
}
}
tensorflow
::
TensorProto
embeddings_proto
;
embeddings
.
AsProtoTensorContent
(
&
embeddings_proto
);
TF_CHECK_OK
(
tensorflow
::
NodeDefBuilder
(
"cell/embedding_matrix"
,
"Const"
)
.
Attr
(
"dtype"
,
tensorflow
::
DT_FLOAT
)
.
Attr
(
"value"
,
embeddings_proto
)
.
Finalize
(
graph
.
add_node
()));
// A Gather op that looks up the |id| in the |embeddings|, and returns the
// result in the |logits|.
TF_CHECK_OK
(
tensorflow
::
NodeDefBuilder
(
"cell/lookup"
,
"Gather"
)
.
Input
(
"cell/embedding_matrix"
,
0
,
tensorflow
::
DT_FLOAT
)
.
Input
(
"cell/id"
,
0
,
TensorFlowType
(
options
.
id_type
))
.
Attr
(
"validate_indices"
,
true
)
.
Finalize
(
graph
.
add_node
()));
const
string
path
=
tensorflow
::
io
::
JoinPath
(
tensorflow
::
testing
::
TmpDir
(),
"graph-frozen"
);
TF_CHECK_OK
(
SaveFrozenGraphDef
(
path
,
graph
));
return
path
;
}
// Creates a component, initializes it based on the |component_spec_text| and
// |flow_path|, and evaluates it. The |component_trace| is overwritten with
// traces, if non-null. On error, returns non-OK.
tensorflow
::
Status
Run
(
const
string
&
component_spec_text
=
""
,
const
string
&
flow_path
=
WriteFrozenGraphDef
(),
ComponentTrace
*
component_trace
=
nullptr
)
{
ComponentSpec
component_spec
;
CHECK
(
TextFormat
::
ParseFromString
(
component_spec_text
,
&
component_spec
));
if
(
!
component_spec
.
has_num_actions
())
{
component_spec
.
set_num_actions
(
kLogitsDim
);
}
component_spec
.
set_name
(
kTestComponentName
);
auto
*
fixed_feature
=
component_spec
.
add_fixed_feature
();
fixed_feature
->
set_embedding_dim
(
-
1
);
fixed_feature
->
set_size
(
1
);
TF_RETURN_IF_ERROR
(
AddFrozenGraphDefResource
(
flow_path
,
&
component_spec
));
component_spec
.
mutable_backend
()
->
set_registered_name
(
"SequenceBackend"
);
auto
&
parameters
=
*
component_spec
.
mutable_component_builder
()
->
mutable_parameters
();
parameters
[
"sequence_extractors"
]
=
"EvenNumbers"
;
parameters
[
"sequence_linkers"
]
=
""
;
parameters
[
"sequence_predictor"
]
=
"NoPredictions"
;
AddComponent
(
kTestComponentName
);
TF_RETURN_IF_ERROR
(
Component
::
CreateOrError
(
"SequenceXlaDynamicComponent"
,
&
component_
));
TF_RETURN_IF_ERROR
(
component_
->
Initialize
(
component_spec
,
&
variable_store_
,
&
network_state_manager_
,
&
extension_manager_
));
network_states_
.
Reset
(
&
network_state_manager_
);
StartComponent
(
0
);
// XlaDynamicComponent will add steps
session_state_
.
extensions
.
Reset
(
&
extension_manager_
);
TF_RETURN_IF_ERROR
(
component_
->
Evaluate
(
&
session_state_
,
&
compute_session_
,
component_trace
));
return
tensorflow
::
Status
::
OK
();
}
// Input batch injected into Evaluate() by default.
InputBatchCache
input_
;
// Backend injected into Evaluate().
SequenceBackend
backend_
;
std
::
unique_ptr
<
Component
>
component_
;
};
// Tests that XlaDynamicComponent fails if the spec uses attention.
TEST_F
(
SequenceXlaDynamicComponentMixinTest
,
UnsupportedAttention
)
{
EXPECT_THAT
(
Run
(
"attention_component:'foo'"
),
test
::
IsErrorWithSubstr
(
"Attention is not supported"
));
}
// Tests that XlaDynamicComponent fails if the spec has embedded fixed
// features.
TEST_F
(
SequenceXlaDynamicComponentMixinTest
,
InvalidFixedFeatureIsEmbedded
)
{
EXPECT_THAT
(
Run
(
"fixed_feature { embedding_dim:1 }"
),
test
::
IsErrorWithSubstr
(
"XLA requires non-embedded fixed features"
));
}
// Tests that XlaDynamicComponent fails if the ComponentSpec has a fixed
// feature that does not appear in the graph.
TEST_F
(
SequenceXlaDynamicComponentMixinTest
,
InvalidFixedFeatureNotInGraph
)
{
EXPECT_THAT
(
Run
(
"fixed_feature { embedding_dim:-1 size:1 }"
),
test
::
IsErrorWithSubstr
(
tensorflow
::
strings
::
StrCat
(
"No XLA tensor named '"
,
MakeXlaInputFixedFeatureIdName
(
1
,
0
),
"'"
)));
}
// Tests that XlaDynamicComponent fails if the spec has multipled linked
// features.
TEST_F
(
SequenceXlaDynamicComponentMixinTest
,
InvalidLinkedFeatureIsMultiplied
)
{
EXPECT_THAT
(
Run
(
"linked_feature { embedding_dim:1 }"
),
test
::
IsErrorWithSubstr
(
"XLA requires non-multiplied linked features"
));
}
// Tests that XlaDynamicComponent fails if the ComponentSpec has a linked
// feature that does not appear in the graph.
TEST_F
(
SequenceXlaDynamicComponentMixinTest
,
InvalidLinkedFeatureNotInGraph
)
{
const
string
kSpec
=
tensorflow
::
strings
::
StrCat
(
"linked_feature { source_component:'"
,
kTestComponentName
,
"' source_layer:'logits' embedding_dim:-1 size:1 }"
);
EXPECT_THAT
(
Run
(
kSpec
),
test
::
IsErrorWithSubstr
(
tensorflow
::
strings
::
StrCat
(
"No XLA tensor named '"
,
MakeXlaInputLinkedActivationVectorName
(
0
),
"'"
)));
}
// Tests that XlaDynamicComponent fails if the GraphDef file does not exist.
TEST_F
(
SequenceXlaDynamicComponentMixinTest
,
InvalidPath
)
{
EXPECT_THAT
(
Run
(
""
,
"/invalid/path"
),
test
::
IsErrorWithSubstr
(
"No such file or directory"
));
}
// Tests that XlaDynamicComponent fails if the logits dimension does not
// match ComponentSpec.num_actions.
TEST_F
(
SequenceXlaDynamicComponentMixinTest
,
WrongLogitsDimension
)
{
GraphDefOptions
options
;
options
.
logits_dim
=
kLogitsDim
+
1
;
EXPECT_THAT
(
Run
(
""
,
WriteFrozenGraphDef
(
options
)),
test
::
IsErrorWithSubstr
(
"Dimension mismatch between classification logits"
));
}
// Tests that XlaDynamicComponent fails if there is no "logits" layer.
TEST_F
(
SequenceXlaDynamicComponentMixinTest
,
WrongLogitsName
)
{
GraphDefOptions
options
;
options
.
logits_name
=
"not_logits"
;
EXPECT_THAT
(
Run
(
""
,
WriteFrozenGraphDef
(
options
)),
test
::
IsErrorWithSubstr
(
"Unknown layer 'logits'"
));
}
// Tests that XlaDynamicComponent fails to compile if one of the XLA
// tensors has the wrong type.
TEST_F
(
SequenceXlaDynamicComponentMixinTest
,
FailToCompile
)
{
GraphDefOptions
options
;
options
.
id_type
=
xla
::
F32
;
EXPECT_THAT
(
Run
(
""
,
WriteFrozenGraphDef
(
options
)),
test
::
IsErrorWithSubstr
(
"float is not in the list of allowed values"
));
}
// Tests that XlaDynamicComponent fails if one of the XLA tensors is not
// vector-like.
TEST_F
(
SequenceXlaDynamicComponentMixinTest
,
NotVectorLike
)
{
GraphDefOptions
options
;
options
.
id_dim
=
2
;
EXPECT_THAT
(
Run
(
""
,
WriteFrozenGraphDef
(
options
)),
test
::
IsErrorWithSubstr
(
"XLA tensor has non-vector-like shape"
));
}
// Tests that XlaDynamicComponent can run a simple non-deterministic frozen
// GraphDef.
TEST_F
(
SequenceXlaDynamicComponentMixinTest
,
SimpleNonDeterministicFlow
)
{
TF_ASSERT_OK
(
Run
());
const
Matrix
<
float
>
logits
(
GetLayer
(
kTestComponentName
,
"logits"
));
ASSERT_EQ
(
logits
.
num_rows
(),
kNumSteps
);
ASSERT_EQ
(
logits
.
num_columns
(),
kLogitsDim
);
// Since each row of the embedding matrix is filled with its index, the logits
// should be equal to the feature IDs.
for
(
int
step_index
=
0
;
step_index
<
kNumSteps
;
++
step_index
)
{
ExpectVector
(
logits
.
row
(
step_index
),
kLogitsDim
,
2
*
step_index
);
}
}
// Tests that XlaDynamicComponent can run a simple deterministic frozen
// GraphDef.
TEST_F
(
SequenceXlaDynamicComponentMixinTest
,
SimpleDeterministicFlow
)
{
GraphDefOptions
options
;
options
.
logits_dim
=
1
;
TF_ASSERT_OK
(
Run
(
"num_actions:1"
,
WriteFrozenGraphDef
(
options
)));
}
// Tests that XlaDynamicComponent can run a simple frozen GraphDef with tracing
// enabled.
TEST_F
(
SequenceXlaDynamicComponentMixinTest
,
SimpleFlowWithTracing
)
{
ComponentTrace
component_trace
;
TF_ASSERT_OK
(
Run
(
""
,
WriteFrozenGraphDef
(),
&
component_trace
));
// Each step trace should have a cell trace from the XLA instance.
ASSERT_EQ
(
component_trace
.
step_trace_size
(),
kNumSteps
);
for
(
const
ComponentStepTrace
&
step_trace
:
component_trace
.
step_trace
())
{
// TODO(googleuser): Add once the JIT API supports this.
EXPECT_EQ
(
step_trace
.
ExtensionSize
(
CellTrace
::
step_trace_extension
),
0
);
}
}
}
// namespace
}
// namespace runtime
}
// namespace dragnn
}
// namespace syntaxnet
research/syntaxnet/dragnn/runtime/xla/testdata/simple-component-spec
deleted
100644 → 0
View file @
a4bb31d0
name: "test_component"
fixed_feature {
embedding_dim: -1
size: 1
}
num_actions: 1
component_builder {
registered_name: "XlaAotDynamicComponent_model_v1_test_component"
}
[syntaxnet.dragnn.runtime.CompilationSpec.component_spec_extension] {
model_name: "model_v1"
cell_subgraph_spec {
input {
name: "fixed_channel_0_index_0_ids"
tensor: "cell/id:0"
type: TYPE_FEATURE
}
output {
name: "logits"
tensor: "cell/lookup:0"
}
}
}
research/syntaxnet/dragnn/runtime/xla/testdata/simple-config.pbtxt
deleted
100644 → 0
View file @
a4bb31d0
feed {
id {
node_name: "cell/id"
}
shape {
dim {
size: 1
}
}
name: "INPUT__fixed_channel_0_index_0_ids"
}
fetch {
id {
node_name: "cell/lookup"
}
name: "OUTPUT__logits"
}
research/syntaxnet/dragnn/runtime/xla/testdata/simple-graph.pbtxt
deleted
100644 → 0
View file @
a4bb31d0
node {
name: "CellSubgraphSpec"
op: "Const"
attr {
key: "dtype"
value {
type: DT_STRING
}
}
attr {
key: "shape"
value {
shape {
dim {
size: 1
}
}
}
}
attr {
key: "value"
value {
tensor {
dtype: DT_STRING
tensor_shape {
dim {
size: 1
}
}
string_val: "\n*\n\033fixed_channel_0_index_0_ids\022\tcell/id:0\030\001\022\027\n\006logits\022\rcell/lookup:0"
}
}
}
}
node {
name: "cell/id"
op: "Placeholder"
attr {
key: "dtype"
value {
type: DT_INT32
}
}
attr {
key: "shape"
value {
shape {
dim {
size: 1
}
}
}
}
}
node {
name: "cell/embedding_matrix"
op: "Const"
attr {
key: "dtype"
value {
type: DT_FLOAT
}
}
attr {
key: "value"
value {
tensor {
dtype: DT_FLOAT
tensor_shape {
dim {
size: 123
}
dim {
size: 1
}
}
tensor_content: "\000\000\000\000\000\000\200?\000\000\000@\000\000@@\000\000\200@\000\000\240@\000\000\300@\000\000\340@\000\000\000A\000\000\020A\000\000 A\000\0000A\000\000@A\000\000PA\000\000`A\000\000pA\000\000\200A\000\000\210A\000\000\220A\000\000\230A\000\000\240A\000\000\250A\000\000\260A\000\000\270A\000\000\300A\000\000\310A\000\000\320A\000\000\330A\000\000\340A\000\000\350A\000\000\360A\000\000\370A\000\000\000B\000\000\004B\000\000\010B\000\000\014B\000\000\020B\000\000\024B\000\000\030B\000\000\034B\000\000 B\000\000$B\000\000(B\000\000,B\000\0000B\000\0004B\000\0008B\000\000<B\000\000@B\000\000DB\000\000HB\000\000LB\000\000PB\000\000TB\000\000XB\000\000\\B\000\000`B\000\000dB\000\000hB\000\000lB\000\000pB\000\000tB\000\000xB\000\000|B\000\000\200B\000\000\202B\000\000\204B\000\000\206B\000\000\210B\000\000\212B\000\000\214B\000\000\216B\000\000\220B\000\000\222B\000\000\224B\000\000\226B\000\000\230B\000\000\232B\000\000\234B\000\000\236B\000\000\240B\000\000\242B\000\000\244B\000\000\246B\000\000\250B\000\000\252B\000\000\254B\000\000\256B\000\000\260B\000\000\262B\000\000\264B\000\000\266B\000\000\270B\000\000\272B\000\000\274B\000\000\276B\000\000\300B\000\000\302B\000\000\304B\000\000\306B\000\000\310B\000\000\312B\000\000\314B\000\000\316B\000\000\320B\000\000\322B\000\000\324B\000\000\326B\000\000\330B\000\000\332B\000\000\334B\000\000\336B\000\000\340B\000\000\342B\000\000\344B\000\000\346B\000\000\350B\000\000\352B\000\000\354B\000\000\356B\000\000\360B\000\000\362B\000\000\364B"
}
}
}
}
node {
name: "cell/lookup"
op: "Gather"
input: "cell/embedding_matrix"
input: "cell/id"
attr {
key: "Tindices"
value {
type: DT_INT32
}
}
attr {
key: "Tparams"
value {
type: DT_FLOAT
}
}
attr {
key: "validate_indices"
value {
b: true
}
}
}
research/syntaxnet/dragnn/runtime/xla/testdata/xla_compilation_output/master-spec
deleted
100644 → 0
View file @
a4bb31d0
component {
name: "rnn"
transition_system {
registered_name: "shift-only"
parameters {
key: "left_to_right"
value: "false"
}
parameters {
key: "parser_skip_deterministic"
value: "false"
}
}
resource {
name: "words-embedding-input"
part {
file_format: "tf-records"
record_format: "syntaxnet.TokenEmbedding"
}
}
resource {
name: "words-vocab-input"
part {
file_format: "text"
record_format: ""
}
}
resource {
name: "char-ngram-map"
part {
file_format: "text"
record_format: ""
}
}
resource {
name: "word-map"
part {
file_format: "text"
record_format: ""
}
}
resource {
name: "label-map"
part {
file_format: "text"
record_format: ""
}
}
resource {
name: "frozen-graph"
part {
file_format: "proto"
record_format: "tensorflow.GraphDef"
}
}
fixed_feature {
name: "char_ngrams"
fml: "input.token { offset(-1).char-ngram(min-length=1,max-length=3,mark-boundaries=true) offset(0).char-ngram(min-length=1,max-length=3,mark-boundaries=true) offset(1).char-ngram(min-length=1,max-length=3,mark-boundaries=true) }"
embedding_dim: -1
vocabulary_size: 25788
size: 3
}
fixed_feature {
name: "words"
fml: "input.token.word(min-freq=2)"
embedding_dim: -1
vocabulary_size: 23769
size: 1
}
network_unit {
registered_name: "LSTMNetwork"
parameters {
key: "hidden_layer_sizes"
value: "128"
}
parameters {
key: "omit_logits"
value: "true"
}
}
backend {
registered_name: "SyntaxNetComponent"
}
num_actions: 1
attention_component: ""
component_builder {
registered_name: "XlaDynamicComponent"
}
}
component {
name: "tagger"
transition_system {
registered_name: "tagger"
parameters {
key: "parser_skip_deterministic"
value: "false"
}
}
resource {
name: "tag-map"
part {
file_format: "text"
record_format: ""
}
}
resource {
name: "tag-to-category"
part {
file_format: "text"
record_format: ""
}
}
resource {
name: "label-map"
part {
file_format: "text"
record_format: ""
}
}
resource {
name: "frozen-graph"
part {
file_format: "proto"
record_format: "tensorflow.GraphDef"
}
}
linked_feature {
name: "recurrence"
fml: "bias(0)"
embedding_dim: -1
size: 1
source_component: "tagger"
source_translator: "history"
source_layer: "layer_0"
}
linked_feature {
name: "rnn"
fml: "input.focus"
embedding_dim: -1
size: 1
source_component: "rnn"
source_translator: "reverse-token"
source_layer: "layer_0"
}
network_unit {
registered_name: "FeedForwardNetwork"
parameters {
key: "hidden_layer_sizes"
value: "64,64"
}
}
backend {
registered_name: "SyntaxNetComponent"
}
num_actions: 45
attention_component: ""
component_builder {
registered_name: "XlaDynamicComponent"
}
}
research/syntaxnet/dragnn/runtime/xla/testdata/xla_compilation_output/master-spec-aot
deleted
100644 → 0
View file @
a4bb31d0
component {
name: "rnn"
transition_system {
registered_name: "shift-only"
parameters {
key: "left_to_right"
value: "false"
}
parameters {
key: "parser_skip_deterministic"
value: "false"
}
}
resource {
name: "words-embedding-input"
part {
file_format: "tf-records"
record_format: "syntaxnet.TokenEmbedding"
}
}
resource {
name: "words-vocab-input"
part {
file_format: "text"
record_format: ""
}
}
resource {
name: "char-ngram-map"
part {
file_format: "text"
record_format: ""
}
}
resource {
name: "word-map"
part {
file_format: "text"
record_format: ""
}
}
resource {
name: "label-map"
part {
file_format: "text"
record_format: ""
}
}
resource {
name: "frozen-graph"
part {
file_format: "proto"
record_format: "tensorflow.GraphDef"
}
}
fixed_feature {
name: "char_ngrams"
fml: "input.token { offset(-1).char-ngram(min-length=1,max-length=3,mark-boundaries=true) offset(0).char-ngram(min-length=1,max-length=3,mark-boundaries=true) offset(1).char-ngram(min-length=1,max-length=3,mark-boundaries=true) }"
embedding_dim: -1
vocabulary_size: 25788
size: 3
}
fixed_feature {
name: "words"
fml: "input.token.word(min-freq=2)"
embedding_dim: -1
vocabulary_size: 23769
size: 1
}
network_unit {
registered_name: "LSTMNetwork"
parameters {
key: "hidden_layer_sizes"
value: "128"
}
parameters {
key: "omit_logits"
value: "true"
}
}
backend {
registered_name: "SyntaxNetComponent"
}
num_actions: 1
attention_component: ""
component_builder {
registered_name: "XlaDynamicComponent"
}
[syntaxnet.dragnn.runtime.CompilationSpec.component_spec_extension] {
model_name: "model_v1"
cell_subgraph_spec {
input {
name: "fixed_channel_0_index_0_ids"
tensor: "rnn/INPUT/fixed_channel_0_index_0_ids:0"
type: TYPE_FEATURE
}
input {
name: "fixed_channel_0_index_1_ids"
tensor: "rnn/INPUT/fixed_channel_0_index_1_ids:0"
type: TYPE_FEATURE
}
input {
name: "fixed_channel_0_index_2_ids"
tensor: "rnn/INPUT/fixed_channel_0_index_2_ids:0"
type: TYPE_FEATURE
}
input {
name: "fixed_channel_1_index_0_ids"
tensor: "rnn/INPUT/fixed_channel_1_index_0_ids:0"
type: TYPE_FEATURE
}
input {
name: "lstm_c"
tensor: "rnn/INPUT/lstm_c:0"
type: TYPE_RECURRENT
}
input {
name: "lstm_h"
tensor: "rnn/INPUT/lstm_h:0"
type: TYPE_RECURRENT
}
output {
name: "lstm_h"
tensor: "annotation/inference_rnn/rnn/lstm_h:0"
}
output {
name: "lstm_c"
tensor: "annotation/inference_rnn/rnn/lstm_c:0"
}
output {
name: "layer_0"
tensor: "annotation/inference_rnn/rnn/layer_0:0"
}
output {
name: "logits"
tensor: "annotation/inference_rnn/rnn/logits:0"
}
}
}
}
component {
name: "tagger"
transition_system {
registered_name: "tagger"
parameters {
key: "parser_skip_deterministic"
value: "false"
}
}
resource {
name: "tag-map"
part {
file_format: "text"
record_format: ""
}
}
resource {
name: "tag-to-category"
part {
file_format: "text"
record_format: ""
}
}
resource {
name: "label-map"
part {
file_format: "text"
record_format: ""
}
}
resource {
name: "frozen-graph"
part {
file_format: "proto"
record_format: "tensorflow.GraphDef"
}
}
linked_feature {
name: "recurrence"
fml: "bias(0)"
embedding_dim: -1
size: 1
source_component: "tagger"
source_translator: "history"
source_layer: "layer_0"
}
linked_feature {
name: "rnn"
fml: "input.focus"
embedding_dim: -1
size: 1
source_component: "rnn"
source_translator: "reverse-token"
source_layer: "layer_0"
}
network_unit {
registered_name: "FeedForwardNetwork"
parameters {
key: "hidden_layer_sizes"
value: "64,64"
}
}
backend {
registered_name: "SyntaxNetComponent"
}
num_actions: 45
attention_component: ""
component_builder {
registered_name: "XlaDynamicComponent"
}
[syntaxnet.dragnn.runtime.CompilationSpec.component_spec_extension] {
model_name: "model_v1"
cell_subgraph_spec {
input {
name: "linked_channel_0_activations"
tensor: "tagger/INPUT/linked_channel_0_activations:0"
type: TYPE_FEATURE
}
input {
name: "linked_channel_0_out_of_bounds"
tensor: "tagger/INPUT/linked_channel_0_out_of_bounds:0"
type: TYPE_FEATURE
}
input {
name: "linked_channel_1_activations"
tensor: "tagger/INPUT/linked_channel_1_activations:0"
type: TYPE_FEATURE
}
output {
name: "layer_0"
tensor: "annotation/inference_tagger/tagger/Relu:0"
}
output {
name: "layer_1"
tensor: "annotation/inference_tagger/tagger/Relu_1:0"
}
output {
name: "last_layer"
tensor: "annotation/inference_tagger/tagger/Relu_1:0"
}
output {
name: "logits"
tensor: "annotation/inference_tagger/tagger/logits:0"
}
}
}
}
research/syntaxnet/dragnn/runtime/xla/testdata/xla_compilation_output/rnn-frozen
deleted
100644 → 0
View file @
a4bb31d0
File deleted
research/syntaxnet/dragnn/runtime/xla/testdata/xla_compilation_output/tagger-frozen
deleted
100644 → 0
View file @
a4bb31d0
File deleted
research/syntaxnet/dragnn/runtime/xla/xla_aot_dynamic_component.h
deleted
100644 → 0
View file @
a4bb31d0
// Copyright 2017 Google Inc. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
#ifndef DRAGNN_RUNTIME_XLA_XLA_AOT_DYNAMIC_COMPONENT_H_
#define DRAGNN_RUNTIME_XLA_XLA_AOT_DYNAMIC_COMPONENT_H_
#include <string>
#include "dragnn/protos/export.pb.h"
#include "dragnn/protos/spec.pb.h"
#include "dragnn/runtime/component.h"
#include "dragnn/runtime/xla/sequence_xla_dynamic_component_mixin.h"
#include "dragnn/runtime/xla/xla_dynamic_component_base.h"
#include "dragnn/runtime/xla/xla_graph_utils.h"
#include "dragnn/runtime/xla/xla_spec_utils.h"
#include "syntaxnet/base.h"
#include "tensorflow/compiler/tf2xla/tf2xla.pb.h"
#include "tensorflow/compiler/tf2xla/xla_compiled_cpu_function.h"
#include "tensorflow/core/framework/graph.pb.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/strings/strcat.h"
namespace
syntaxnet
{
namespace
dragnn
{
namespace
runtime
{
// An XLA-based version of DynamicComponent using an XLA AOT compiled library.
//
// The class |AotCell| is generated by a tf_library build rule.
//
// The component class is instantiated in C++ code generated by a
// dragnn_xla_aot_components() build rule. The default constructor must set
// the model and component names to non-empty strings, and this must match
// the registered class name, as generated by RegisteredName().
//
// Example instantiation and registration:
//
// class XlaAotDynamicComponent_model_component
// : public XlaAotDynamicComponent<model::component> {
// public:
// XlaAotDynamicComponent_model_component()
// : XlaAotDynamicComponent<model::component>("model", "component") {}
// };
// DRAGNN_RUNTIME_REGISTER_COMPONENT(XlaAotDynamicComponent_model_component);
template
<
typename
AotCell
>
class
XlaAotDynamicComponent
:
public
XlaDynamicComponentBase
{
protected:
XlaAotDynamicComponent
(
const
string
&
model_name
,
const
string
&
component_name
)
:
model_name_
(
model_name
),
component_name_
(
component_name
)
{}
// Unlike other specializations, this component will only be active if the
// spec is explicitly modified to support XLA AOT.
bool
Supports
(
const
ComponentSpec
&
spec
,
const
string
&
normalized_builder_name
)
const
override
{
// This must accept both the "base" XLA component and this one, based on how
// Supports is called repeatedly.
return
(
normalized_builder_name
==
"XlaDynamicComponent"
||
normalized_builder_name
==
RegisteredName
())
&&
spec
.
name
()
==
component_name_
&&
ModelNameForComponent
(
spec
)
==
model_name_
&&
GetCellSubgraphSpecForComponent
(
spec
,
nullptr
).
ok
();
}
bool
PreferredTo
(
const
Component
&
other
)
const
override
{
// AOT is preferred to JIT.
return
true
;
}
// Gets the frozen GraphDef using the |component_spec| and compiles it.
// The |cell_subgraph_spec| contained within it is filled in. On error,
// returns non-OK.
tensorflow
::
Status
InitializeFromComponentSpec
(
const
ComponentSpec
&
component_spec
,
CellSubgraphSpec
*
cell_subgraph_spec
)
override
;
const
tensorflow
::
XlaCompiledCpuFunction
::
StaticData
&
XlaStaticData
()
const
override
{
return
AotCell
::
StaticData
();
}
private:
const
string
RegisteredName
()
const
{
return
tensorflow
::
strings
::
StrCat
(
"XlaAotDynamicComponent_"
,
model_name_
,
"_"
,
component_name_
);
}
const
string
model_name_
;
const
string
component_name_
;
};
template
<
typename
AotCell
>
tensorflow
::
Status
XlaAotDynamicComponent
<
AotCell
>::
InitializeFromComponentSpec
(
const
ComponentSpec
&
component_spec
,
CellSubgraphSpec
*
cell_subgraph_spec
)
{
LOG
(
INFO
)
<<
"Using XLA AOT library for model/component: "
<<
model_name_
<<
"/"
<<
component_name_
;
CHECK
(
!
model_name_
.
empty
()
&&
!
component_name_
.
empty
());
return
GetCellSubgraphSpecForComponent
(
component_spec
,
cell_subgraph_spec
);
}
// Sequence-based version of the above.
template
<
typename
AotCell
>
using
SequenceXlaAotDynamicComponent
=
SequenceXlaDynamicComponentMixin
<
XlaAotDynamicComponent
<
AotCell
>>
;
}
// namespace runtime
}
// namespace dragnn
}
// namespace syntaxnet
#endif // DRAGNN_RUNTIME_XLA_XLA_AOT_DYNAMIC_COMPONENT_H_
research/syntaxnet/dragnn/runtime/xla/xla_aot_dynamic_component_test.cc
deleted
100644 → 0
View file @
a4bb31d0
// Copyright 2017 Google Inc. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
#include "dragnn/runtime/xla/xla_aot_dynamic_component.h"
#include "dragnn/core/test/generic.h"
#include "dragnn/protos/export.pb.h"
#include "dragnn/protos/spec.pb.h"
#include "dragnn/runtime/test/network_test_base.h"
#include "tensorflow/compiler/tf2xla/xla_compiled_cpu_function.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/lib/io/path.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/test.h"
using
::
testing
::
_
;
using
::
testing
::
InSequence
;
using
::
testing
::
Invoke
;
namespace
syntaxnet
{
namespace
dragnn
{
namespace
runtime
{
namespace
{
// Fake AOT class suitable for testing initialization.
class
TestComponent
{
public:
static
const
tensorflow
::
XlaCompiledCpuFunction
::
StaticData
&
StaticData
()
{
static
tensorflow
::
XlaCompiledCpuFunction
::
StaticData
*
kStaticData
=
new
tensorflow
::
XlaCompiledCpuFunction
::
StaticData
;
return
*
kStaticData
;
}
};
constexpr
char
kXlaModel
[]
=
"TestModel"
;
constexpr
char
kXlaComponent
[]
=
"TestComponent"
;
class
XlaAotDynamicComponent_TestModel_TestComponent
:
public
XlaAotDynamicComponent
<
TestComponent
>
{
public:
XlaAotDynamicComponent_TestModel_TestComponent
()
:
XlaAotDynamicComponent
<
TestComponent
>
(
kXlaModel
,
kXlaComponent
)
{}
using
XlaAotDynamicComponent
<
TestComponent
>::
Supports
;
using
XlaAotDynamicComponent
<
TestComponent
>::
InitializeFromComponentSpec
;
};
DRAGNN_RUNTIME_REGISTER_COMPONENT
(
XlaAotDynamicComponent_TestModel_TestComponent
);
class
XlaAotDynamicComponentTest
:
public
::
testing
::
Test
{
public:
// Test util that builds a ComponentSpec with |component_name| set (if
// non-empty). A CompilationSpec extension contains |model_name| (if
// non-empty) and an empty CellSubgraphSpec if |include_subgraph_spec| is
// true. No extension is added if |model_name| is empty and
// |include_subgraph_spec| is false.
ComponentSpec
BuildComponentSpec
(
const
string
&
model_name
,
const
string
&
component_name
,
bool
include_subgraph_spec
)
{
ComponentSpec
spec
;
if
(
!
component_name
.
empty
())
spec
.
set_name
(
component_name
);
// Add the extension if anything is in it.
if
(
!
model_name
.
empty
()
||
include_subgraph_spec
)
{
auto
*
compilation_spec
=
spec
.
MutableExtension
(
CompilationSpec
::
component_spec_extension
);
if
(
!
model_name
.
empty
())
compilation_spec
->
set_model_name
(
model_name
);
if
(
include_subgraph_spec
)
{
CellSubgraphSpec
cell_subgraph_spec
;
*
compilation_spec
->
mutable_cell_subgraph_spec
()
=
cell_subgraph_spec
;
}
}
return
spec
;
}
protected:
XlaAotDynamicComponent_TestModel_TestComponent
component_
;
};
TEST_F
(
XlaAotDynamicComponentTest
,
Supports
)
{
ComponentSpec
spec
=
BuildComponentSpec
(
kXlaModel
,
kXlaComponent
,
true
);
EXPECT_TRUE
(
component_
.
Supports
(
spec
,
"XlaDynamicComponent"
));
EXPECT_TRUE
(
component_
.
Supports
(
spec
,
"XlaAotDynamicComponent_TestModel_TestComponent"
));
EXPECT_FALSE
(
component_
.
Supports
(
spec
,
"DynamicComponent"
));
EXPECT_FALSE
(
component_
.
Supports
(
spec
,
"XlaAotDynamicComponent"
));
EXPECT_FALSE
(
component_
.
Supports
(
spec
,
"XlaAotDynamicComponent_TestModel_OtherComponent"
));
}
TEST_F
(
XlaAotDynamicComponentTest
,
SupportRequiresMatchingModelName
)
{
EXPECT_FALSE
(
component_
.
Supports
(
BuildComponentSpec
(
"OtherModel"
,
kXlaComponent
,
true
),
"XlaDynamicComponent"
));
EXPECT_FALSE
(
component_
.
Supports
(
BuildComponentSpec
(
""
,
kXlaComponent
,
true
),
"XlaDynamicComponent"
));
}
TEST_F
(
XlaAotDynamicComponentTest
,
SupportRequiresSubgraph
)
{
EXPECT_FALSE
(
component_
.
Supports
(
BuildComponentSpec
(
kXlaModel
,
kXlaComponent
,
false
),
"XlaDynamicComponent"
));
}
TEST_F
(
XlaAotDynamicComponentTest
,
InitializeFromComponentSpec
)
{
ComponentSpec
component_spec
;
auto
*
compilation_spec
=
component_spec
.
MutableExtension
(
CompilationSpec
::
component_spec_extension
);
// Example spec.
CellSubgraphSpec
expected_cell_subgraph_spec
;
auto
*
input
=
expected_cell_subgraph_spec
.
add_input
();
input
->
set_name
(
"fixed_channel_0_index_0_ids"
);
input
->
set_tensor
(
"cell/id:0"
);
input
->
set_type
(
CellSubgraphSpec
::
Input
::
TYPE_FEATURE
);
auto
*
output
=
expected_cell_subgraph_spec
.
add_output
();
output
->
set_name
(
"logits"
);
output
->
set_tensor
(
"cell/lookup:0"
);
*
compilation_spec
->
mutable_cell_subgraph_spec
()
=
expected_cell_subgraph_spec
;
CellSubgraphSpec
actual_cell_subgraph_spec
;
TF_ASSERT_OK
(
component_
.
InitializeFromComponentSpec
(
component_spec
,
&
actual_cell_subgraph_spec
));
EXPECT_THAT
(
actual_cell_subgraph_spec
,
test
::
EqualsProto
(
expected_cell_subgraph_spec
));
}
TEST_F
(
XlaAotDynamicComponentTest
,
InitializeFromComponentSpecNeedsSubgraph
)
{
CellSubgraphSpec
cell_subgraph_spec
;
TF_EXPECT_OK
(
component_
.
InitializeFromComponentSpec
(
BuildComponentSpec
(
kXlaModel
,
kXlaComponent
,
true
),
&
cell_subgraph_spec
));
EXPECT_THAT
(
component_
.
InitializeFromComponentSpec
(
BuildComponentSpec
(
kXlaModel
,
kXlaComponent
,
false
),
&
cell_subgraph_spec
),
test
::
IsErrorWithSubstr
(
"Component TestComponent does not have a CellSubgraphSpec"
));
}
// Tests using simple test AOT library.
constexpr
int
kNumSteps
=
50
;
constexpr
int
kVocabularySize
=
123
;
constexpr
char
kSimpleComponentSpecPath
[]
=
"dragnn/runtime/xla/testdata/simple-component-spec"
;
class
XlaAotDynamicComponentRunTest
:
public
NetworkTestBase
{
public:
// Creates a component, initializes it based on the |component_spec|,
// and evaluates it. On error, returns non-OK.
tensorflow
::
Status
Run
(
const
ComponentSpec
&
component_spec
)
{
AddComponent
(
kTestComponentName
);
TF_RETURN_IF_ERROR
(
Component
::
CreateOrError
(
"XlaAotDynamicComponent_model_v1_test_component"
,
&
component_
));
TF_RETURN_IF_ERROR
(
component_
->
Initialize
(
component_spec
,
&
variable_store_
,
&
network_state_manager_
,
&
extension_manager_
));
network_states_
.
Reset
(
&
network_state_manager_
);
StartComponent
(
0
);
session_state_
.
extensions
.
Reset
(
&
extension_manager_
);
TF_RETURN_IF_ERROR
(
component_
->
Evaluate
(
&
session_state_
,
&
compute_session_
,
nullptr
));
return
tensorflow
::
Status
::
OK
();
}
private:
std
::
unique_ptr
<
Component
>
component_
;
};
// Test that runs a simple deterministic component.
TEST_F
(
XlaAotDynamicComponentRunTest
,
Simple
)
{
SetupTransitionLoop
(
kNumSteps
);
EXPECT_CALL
(
compute_session_
,
AdvanceFromOracle
(
kTestComponentName
))
.
Times
(
kNumSteps
);
{
// Extract a sequence of feature IDs equal to 2 * step_index.
ASSERT_LE
(
2
*
kNumSteps
,
kVocabularySize
);
InSequence
scoped
;
for
(
int
step_index
=
0
;
step_index
<
kNumSteps
;
++
step_index
)
{
EXPECT_CALL
(
compute_session_
,
GetInputFeatures
(
_
,
_
,
_
,
_
,
_
))
.
WillOnce
(
Invoke
(
ExtractFeatures
(
0
,
{{
2
*
step_index
,
1.0
}})));
}
}
ComponentSpec
component_spec
;
TF_ASSERT_OK
(
tensorflow
::
ReadTextProto
(
tensorflow
::
Env
::
Default
(),
tensorflow
::
io
::
JoinPath
(
test
::
GetTestDataPrefix
(),
kSimpleComponentSpecPath
),
&
component_spec
));
TF_ASSERT_OK
(
Run
(
component_spec
));
}
}
// namespace
}
// namespace runtime
}
// namespace dragnn
}
// namespace syntaxnet
research/syntaxnet/dragnn/runtime/xla/xla_build_defs.bzl
deleted
100644 → 0
View file @
a4bb31d0
# Copyright 2017 Google Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# =============================================================================
"""Build extension rules for XLA AOT compilation."""
load
(
"//dragnn/runtime:multiarch.bzl"
,
"multiarch_name"
,
"MULTIARCH_CONFIGS"
,
)
load
(
"@org_tensorflow//tensorflow/compiler/aot:tfcompile.bzl"
,
"tf_library"
)
MULTIARCH_TFCOMPILE_FLAGS
=
{
"generic"
:
[],
"avx"
:
[
"--target_features=+avx,+sse4.2"
],
"avx2fma"
:
[
"--target_features=+avx,+avx2,+sse4.2,+fma"
],
}
def
_dragnn_xla_safe_name
(
name
):
"""Generates a version of |name| is safe for use in C++."""
return
name
.
replace
(
'-'
,
'_'
).
replace
(
'.'
,
'_'
)
def
_dragnn_xla_aot_library_name
(
arch
,
model
,
component
):
"""Returns the AOT library name for the given model/component."""
return
multiarch_name
(
model
+
'_'
+
component
,
arch
)
def
_dragnn_xla_aot_component_library_name
(
arch
,
model
,
component
):
"""Returns the AOT component library name for the given model/component."""
return
_dragnn_xla_aot_library_name
(
arch
,
model
,
component
)
+
'_component'
def
_dragnn_xla_config_proto
(
name
,
graph
,
config_tool
=
'//dragnn/runtime/xla:xla_extract_config'
):
"""Extracts XLA Config from a frozen GraphDef for a DRAGNN component.
Generates a build target called |name| which is a text file that contains
a tensorflow.tf2xla.Config used in a tf_library build rule. The output
file is called "<name>.pbtxt".
Args:
name: The name of the build rule.
graph: The frozen tensorflow.GraphDef binary proto built for a particular
DRAGNN component by the runtime.
config_tool: The binary used to extract the Config proto. A non-default
can be passed when necessary.
"""
config_path
=
name
+
'.pbtxt'
native
.
genrule
(
name
=
name
,
srcs
=
[
graph
],
outs
=
[
config_path
],
tools
=
[
config_tool
],
cmd
=
(
'$(location '
+
config_tool
+
')'
+
' $(location '
+
graph
+
')'
+
' $(location '
+
config_path
+
')'
)
)
def
_dragnn_xla_aot_component_cc_code
(
arch
,
model
,
component
,
target
):
"""Generates C++ code for a component which wraps a particular AOT library.
Returns a string containing the generated C++ code that defines and registers
the DRAGNN component the implements a particular |model| and |component|,
targeted to a the given |arch|. The class name and registry name do not
include |arch|, which means only one can be linked in.
Args:
arch: The name of the target architecture.
model: The name of the DRAGNN model.
component: The name of the DRAGNN component that uses XLA AOT.
target: The directory that contains XLA AOT target.
Returns:
The string containing the generated C++ code.
"""
cc_template
=
"""// GENERATED CODE.
#include "$TARGET/$MODEL_$COMPONENT_multiarch_$ARCH.h" // Generated by XLA.
#include "dragnn/runtime/xla/sequence_xla_dynamic_component_mixin.h"
#include "dragnn/runtime/xla/xla_aot_dynamic_component.h"
namespace syntaxnet {
namespace dragnn {
namespace runtime {
namespace {
class XlaAotDynamicComponent_$MODEL_$COMPONENT
: public XlaAotDynamicComponent<$MODEL::$COMPONENT> {
public:
XlaAotDynamicComponent_$MODEL_$COMPONENT()
: XlaAotDynamicComponent<$MODEL::$COMPONENT>("$MODEL", "$COMPONENT") {}
};
DRAGNN_RUNTIME_REGISTER_COMPONENT(XlaAotDynamicComponent_$MODEL_$COMPONENT);
using SequenceXlaAotDynamicComponent_$MODEL_$COMPONENT =
SequenceXlaDynamicComponentMixin<XlaAotDynamicComponent_$MODEL_$COMPONENT>;
DRAGNN_RUNTIME_REGISTER_COMPONENT(
SequenceXlaAotDynamicComponent_$MODEL_$COMPONENT);
} // namespace
} // namespace runtime
} // namespace dragnn
} // namespace syntaxnet
"""
return
cc_template
.
replace
(
'$ARCH'
,
arch
).
replace
(
'$TARGET'
,
target
).
replace
(
'$MODEL'
,
model
).
replace
(
'$COMPONENT'
,
component
)
def
_dragnn_xla_aot_component_library
(
arch
,
model
,
component
,
tags
=
None
,
testonly
=
0
):
"""Generates and compiles the component library that wraps the AOT binary.
Args:
arch: The name of the target architecture.
model: The name of the DRAGNN model.
component: The name of the DRAGNN component that uses XLA AOT.
tags: tags to apply to subsidiary build rules.
testonly: If 1, only testonly targets can depend on this target.
"""
xla_aot_library
=
_dragnn_xla_aot_library_name
(
arch
,
model
,
component
)
xla_aot_component_library
=
_dragnn_xla_aot_component_library_name
(
arch
,
model
,
component
)
xla_aot_component_src
=
xla_aot_component_library
+
'.cc'
native
.
genrule
(
name
=
xla_aot_component_library
+
'_cc'
,
outs
=
[
xla_aot_component_src
],
cmd
=
"cat << 'EOF' >$@
\n
{}
\n
EOF
\n
"
.
format
(
_dragnn_xla_aot_component_cc_code
(
arch
,
model
,
component
,
native
.
package_name
())
),
tags
=
tags
,
testonly
=
testonly
,
)
native
.
cc_library
(
name
=
xla_aot_component_library
,
srcs
=
[
xla_aot_component_src
],
deps
=
[
multiarch_name
(
'//dragnn/runtime/xla:sequence_xla_dynamic_component_mixin'
,
arch
),
multiarch_name
(
'//dragnn/runtime/xla:xla_aot_dynamic_component'
,
arch
),
':'
+
xla_aot_library
,
],
testonly
=
testonly
,
alwayslink
=
1
)
def
_dragnn_xla_aot_library
(
name
,
arch
,
model
,
component
,
graph
,
tags
=
None
,
testonly
=
0
):
"""Runs tfcompile to AOT-compile a frozen GraphDef for a DRAGNN component.
Generates a build target called |name| which is a cc_library containing
the generated header and AOT-compiled function that implements a specific
DRAGNN component. For details on compilation see:
@org_tensorflow//tensorflow/compiler/aot/tfcompile.bzl
The generated library contains the following C++ class:
syntaxnet::dragnn::runtime::<model>::<component>
and the output file is called <name>.h
There is also build target called <name>-config which contains the
Config proto used by XLA.
Args:
name: The name of the build rule.
arch: The name of the target architecture.
model: The name of the DRAGNN model that contains this component.
component: The name of the DRAGNN component in the ComponentSpec.
graph: The frozen tensorflow.GraphDef binary proto built for a particular
DRAGNN component by the runtime.
tags: tags to apply to subsidiary build rules.
testonly: If 1, only testonly targets can depend on this target.
"""
# Gets the Config proto needed by tfcompile.
xla_config_name
=
name
+
'-config'
_dragnn_xla_config_proto
(
name
=
xla_config_name
,
graph
=
graph
)
# Runs tfcompile to AOT-compile the GraphDef.
tf_library
(
name
=
_dragnn_xla_aot_library_name
(
arch
,
model
,
component
),
graph
=
graph
,
config
=
xla_config_name
,
cpp_class
=
'syntaxnet::dragnn::runtime::'
+
model
+
'::'
+
component
,
tfcompile_flags
=
' '
.
join
([
'--gen_name_to_index=true'
,
'--gen_program_shape=true'
,
'--xla_cpu_multi_thread_eigen=false'
,
]
+
MULTIARCH_TFCOMPILE_FLAGS
[
arch
]),
tags
=
tags
,
testonly
=
testonly
,
)
# Generates the component library that wraps the AOT library.
_dragnn_xla_aot_component_library
(
arch
,
model
,
component
,
tags
,
testonly
)
def
dragnn_xla_aot_components
(
name
,
component_data
,
tags
=
None
,
testonly
=
0
):
"""Generates targets for all XLA AOT components in |component_data|.
Every element in the list |component_data| is also a list, which contains:
- name of the DRAGNN model;
- name of the component;
- relative path to the frozen GraphDef proto.
If multiple models exist in the same binary, the model name must uniquely
identify this specific model instance, e.g. 'parser_v20171101'.
Args:
name: The name of the build rule.
component_data: A list of per-component-data that is necessary to build
the AOT library and the component that wraps it.
tags: tags to apply to subsidiary build rules; the arch-specific tags
are included.
testonly: If 1, only testonly targets can depend on this target.
"""
safe_component_data
=
[
[
_dragnn_xla_safe_name
(
model
),
_dragnn_xla_safe_name
(
component
),
graph
]
for
[
model
,
component
,
graph
]
in
component_data
]
# Generates the AOT library and component targets.
for
arch
in
MULTIARCH_TFCOMPILE_FLAGS
:
for
[
model
,
component
,
graph_path
]
in
safe_component_data
:
_dragnn_xla_aot_library
(
name
=
_dragnn_xla_aot_library_name
(
arch
,
model
,
component
),
arch
=
arch
,
model
=
model
,
component
=
component
,
graph
=
graph_path
,
tags
=
(
tags
if
tags
else
[])
+
MULTIARCH_CONFIGS
[
arch
][
'tags'
],
testonly
=
testonly
,
)
# Composes a library with all of the AOT library and component targets.
for
arch
in
MULTIARCH_TFCOMPILE_FLAGS
:
native
.
cc_library
(
name
=
multiarch_name
(
name
,
arch
),
deps
=
[
':'
+
_dragnn_xla_aot_component_library_name
(
arch
,
model
,
component
)
for
[
model
,
component
,
_
]
in
safe_component_data
],
tags
=
(
tags
if
tags
else
[])
+
MULTIARCH_CONFIGS
[
arch
][
'tags'
],
testonly
=
testonly
,
)
def
dragnn_xla_aot_bazel_test
(
name
,
srcs
):
"""Verifies that generated bzl matches what is checked in.
Passes when the generated file <name>_gen.bzl and the currently
existing one in <name>.bzl match.
Args:
name: The name of the bzl to test (without .bzl)
srcs: A set of MasterSpec files
"""
generated_bzl
=
name
+
'-gen.bzl'
native
.
genrule
(
name
=
name
+
'_gen'
,
outs
=
[
generated_bzl
],
cmd
=
(
'$(location '
+
'//dragnn/runtime/xla:xla_extract_names_from_specs) '
+
native
.
package_name
()
+
' $(SRCS) $(OUTS)'
),
tools
=
[
'//dragnn/runtime/xla:xla_extract_names_from_specs'
],
srcs
=
srcs
)
# Makes a copy of file_diff_test in this package.
native
.
genrule
(
name
=
'repackage_file_diff_test'
,
srcs
=
[
'//dragnn/python:file_diff_test.py'
],
outs
=
[
'%s/file_diff_test.py'
%
native
.
package_name
()],
cmd
=
'cp $< $@'
,
)
# Compare the generated file.
expected_bzl
=
name
+
'.bzl'
native
.
py_test
(
name
=
name
,
srcs
=
[
'%s/file_diff_test.py'
%
native
.
package_name
()],
main
=
'%s/file_diff_test.py'
%
native
.
package_name
(),
deps
=
[
'//dragnn/python:file_diff_test'
],
args
=
[
'--actual_file=$(location '
+
generated_bzl
+
')'
,
'--expected_file=$(location '
+
expected_bzl
+
')'
,
],
data
=
[
expected_bzl
,
generated_bzl
],
)
research/syntaxnet/dragnn/runtime/xla/xla_cell_converter.cc
deleted
100644 → 0
View file @
a4bb31d0
// Copyright 2017 Google Inc. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
#include "dragnn/runtime/xla/xla_cell_converter.h"
#include <vector>
#include "dragnn/runtime/xla/xla_graph_utils.h"
#include "tensorflow/core/framework/attr_value.pb.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/strings/numbers.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/logging.h"
namespace
syntaxnet
{
namespace
dragnn
{
namespace
runtime
{
namespace
{
// Returns true if the |tensor_name| denotes a control dependency.
bool
IsControlDependency
(
const
string
&
tensor_name
)
{
return
tensor_name
[
0
]
==
'^'
;
}
// Returns the name of the node that supplies the input called |input_name|.
// This strips off any prefix on control dependencies and any suffix
// for specifying tensor output.
const
string
GetNodeNameFromInput
(
const
string
&
input_name
)
{
return
input_name
.
substr
(
IsControlDependency
(
input_name
)
?
1
:
0
,
input_name
.
rfind
(
':'
));
}
// Returns true if the |node| is a TF variable.
bool
IsVariableNode
(
const
tensorflow
::
NodeDef
&
node
)
{
return
node
.
op
()
==
"VariableV2"
;
}
// Returns true if the |node| is skippable and can be changed
// to an Identity node.
bool
IsNodeConvertibleToIdentity
(
const
tensorflow
::
NodeDef
&
node
)
{
return
node
.
op
()
==
"Enter"
;
}
// Returns true if the node attribute with |name| is one that should always be
// retained, when a node is being simplified or frozen.
bool
AlwaysKeepAttribute
(
const
string
&
name
)
{
return
name
==
"_output_shapes"
||
name
==
"T"
||
name
==
"dtype"
;
}
// Generates the name of the node that contains the serialized CellSubgraphSpec
// given a particular |component_name|.
string
MakeCellSubgraphSpecNodeName
(
const
string
&
component_name
)
{
return
tensorflow
::
strings
::
StrCat
(
component_name
,
"/EXPORT/CellSubgraphSpec"
);
}
// Loads the CellSubgraphSpec for the component named |component_name| from the
// |trained_model| into the |spec|. On error, returns non-OK.
tensorflow
::
Status
LoadCellSubgraphSpec
(
const
string
&
component_name
,
const
TrainedModel
&
trained_model
,
CellSubgraphSpec
*
spec
)
{
const
string
tensor_name
=
MakeCellSubgraphSpecNodeName
(
component_name
);
tensorflow
::
Tensor
tensor
;
TF_RETURN_IF_ERROR
(
trained_model
.
EvaluateTensor
(
tensor_name
,
&
tensor
));
if
(
!
spec
->
ParseFromString
(
tensor
.
scalar
<
string
>
()()))
{
return
tensorflow
::
errors
::
InvalidArgument
(
"Failed to parse CellSubgraphSpec for component "
,
component_name
);
}
VLOG
(
1
)
<<
tensor_name
<<
" =
\n
"
<<
spec
->
DebugString
();
return
tensorflow
::
Status
::
OK
();
}
}
// namespace
tensorflow
::
Status
XlaCellConverter
::
FillNode
(
const
tensorflow
::
NodeDef
&
src_node
,
tensorflow
::
NodeDef
*
dest_node
)
const
{
dest_node
->
set_name
(
src_node
.
name
());
dest_node
->
set_device
(
src_node
.
device
());
if
(
IsNodeConvertibleToIdentity
(
src_node
))
{
dest_node
->
set_op
(
"Identity"
);
FillNodeAttributes
(
true
,
src_node
,
dest_node
);
}
else
{
dest_node
->
set_op
(
src_node
.
op
());
FillNodeAttributes
(
false
,
src_node
,
dest_node
);
}
for
(
const
string
&
input
:
src_node
.
input
())
{
if
(
IsNodeInSubgraph
(
GetNodeNameFromInput
(
input
)))
{
dest_node
->
add_input
(
input
);
}
}
return
tensorflow
::
Status
::
OK
();
}
tensorflow
::
Status
XlaCellConverter
::
FreezeSpecNode
(
const
tensorflow
::
NodeDef
&
src_node
,
tensorflow
::
NodeDef
*
dest_node
)
const
{
dest_node
->
set_name
(
kFrozenCellSubgraphSpecNodeName
);
dest_node
->
set_op
(
"Const"
);
FillNodeAttributes
(
true
,
src_node
,
dest_node
);
tensorflow
::
Tensor
tensor
;
TF_RETURN_IF_ERROR
(
trained_model_
->
EvaluateTensor
(
AsVariableName
(
TensorId
(
src_node
.
name
(),
0
)),
&
tensor
));
// Leaves constants directly accessible, which allows for simple
// extraction of the value.
tensor
.
AsProtoField
((
*
dest_node
->
mutable_attr
())[
"value"
].
mutable_tensor
());
return
tensorflow
::
Status
::
OK
();
}
tensorflow
::
Status
XlaCellConverter
::
FreezeNode
(
const
tensorflow
::
NodeDef
&
src_node
,
tensorflow
::
NodeDef
*
dest_node
)
const
{
dest_node
->
set_name
(
src_node
.
name
());
dest_node
->
set_op
(
"Const"
);
FillNodeAttributes
(
true
,
src_node
,
dest_node
);
tensorflow
::
Tensor
tensor
;
TF_RETURN_IF_ERROR
(
trained_model_
->
EvaluateTensor
(
AsVariableName
(
TensorId
(
src_node
.
name
(),
0
)),
&
tensor
));
// Compactly stores tensor constants.
tensor
.
AsProtoTensorContent
(
(
*
dest_node
->
mutable_attr
())[
"value"
].
mutable_tensor
());
return
tensorflow
::
Status
::
OK
();
}
void
XlaCellConverter
::
FillNodeAttributes
(
bool
restrict_attributes
,
const
tensorflow
::
NodeDef
&
src_node
,
tensorflow
::
NodeDef
*
dest_node
)
{
for
(
const
auto
&
attr
:
src_node
.
attr
())
{
if
(
!
restrict_attributes
||
AlwaysKeepAttribute
(
attr
.
first
))
{
(
*
dest_node
->
mutable_attr
())[
attr
.
first
]
=
attr
.
second
;
}
}
}
bool
XlaCellConverter
::
IsNodeInSubgraph
(
const
string
&
node_name
)
const
{
return
operations_
.
find
(
node_name
)
!=
operations_
.
end
();
}
tensorflow
::
Status
XlaCellConverter
::
Convert
(
const
string
&
component_name
,
const
TrainedModel
&
trained_model
,
tensorflow
::
GraphDef
*
graph
,
CellSubgraphSpec
*
spec
)
{
return
XlaCellConverter
().
ConvertImpl
(
component_name
,
trained_model
,
graph
,
spec
);
}
tensorflow
::
Status
XlaCellConverter
::
ConvertImpl
(
const
string
&
component_name
,
const
TrainedModel
&
trained_model
,
tensorflow
::
GraphDef
*
graph
,
CellSubgraphSpec
*
spec
)
{
component_name_
=
component_name
;
trained_model_
=
&
trained_model
;
TF_RETURN_IF_ERROR
(
LoadCellSubgraphSpec
(
component_name_
,
*
trained_model_
,
spec
));
TF_RETURN_IF_ERROR
(
BuildInputsAndOutputs
(
*
spec
));
TF_RETURN_IF_ERROR
(
BuildOperations
());
graph
->
Clear
();
const
tensorflow
::
GraphDef
*
input_graph
;
TF_RETURN_IF_ERROR
(
trained_model_
->
GraphDef
(
&
input_graph
));
// Adds in the CellSubgraphSpec node for this component.
const
tensorflow
::
NodeDef
*
cell_subgraph_spec_node
=
nullptr
;
TF_RETURN_IF_ERROR
(
trained_model_
->
LookupNode
(
MakeCellSubgraphSpecNodeName
(
component_name_
),
&
cell_subgraph_spec_node
));
TF_RETURN_IF_ERROR
(
FreezeSpecNode
(
*
cell_subgraph_spec_node
,
graph
->
add_node
()));
// Adds in frozen versions of the nodes needed for this cell.
for
(
const
tensorflow
::
NodeDef
&
node
:
input_graph
->
node
())
{
if
(
IsNodeInSubgraph
(
node
.
name
()))
{
if
(
IsVariableNode
(
node
))
{
TF_RETURN_IF_ERROR
(
FreezeNode
(
node
,
graph
->
add_node
()));
}
else
{
TF_RETURN_IF_ERROR
(
FillNode
(
node
,
graph
->
add_node
()));
}
}
}
return
tensorflow
::
Status
::
OK
();
}
tensorflow
::
Status
XlaCellConverter
::
BuildInputsAndOutputs
(
const
CellSubgraphSpec
&
spec
)
{
std
::
set
<
string
>
unique_input_names
;
for
(
const
CellSubgraphSpec
::
Input
&
input
:
spec
.
input
())
{
if
(
!
unique_input_names
.
insert
(
input
.
name
()).
second
)
{
return
tensorflow
::
errors
::
InvalidArgument
(
"Duplicate input name { "
,
input
.
ShortDebugString
(),
" }"
);
}
TensorId
tensor_id
;
TF_RETURN_IF_ERROR
(
ParseTensorId
(
input
.
tensor
(),
&
tensor_id
));
if
(
!
inputs_
.
insert
(
tensor_id
).
second
)
{
return
tensorflow
::
errors
::
InvalidArgument
(
"Duplicate input variable { "
,
input
.
ShortDebugString
(),
" }"
);
}
}
std
::
set
<
string
>
unique_output_names
;
for
(
const
CellSubgraphSpec
::
Output
&
output
:
spec
.
output
())
{
if
(
!
unique_output_names
.
insert
(
output
.
name
()).
second
)
{
return
tensorflow
::
errors
::
InvalidArgument
(
"Duplicate output name { "
,
output
.
ShortDebugString
(),
" }"
);
}
TensorId
tensor_id
;
TF_RETURN_IF_ERROR
(
ParseTensorId
(
output
.
tensor
(),
&
tensor_id
));
outputs_
.
insert
(
tensor_id
);
}
// Check that recurrent inputs match the name of an output.
for
(
const
CellSubgraphSpec
::
Input
&
input
:
spec
.
input
())
{
if
(
input
.
type
()
!=
CellSubgraphSpec
::
Input
::
TYPE_RECURRENT
)
continue
;
if
(
unique_output_names
.
find
(
input
.
name
())
==
unique_output_names
.
end
())
{
return
tensorflow
::
errors
::
InvalidArgument
(
"Recurrent input does not match any output { "
,
input
.
ShortDebugString
(),
" }"
);
}
}
return
tensorflow
::
Status
::
OK
();
}
tensorflow
::
Status
XlaCellConverter
::
BuildOperations
()
{
// Extract sets of input and output node names.
std
::
set
<
string
>
input_node_names
;
std
::
set
<
string
>
output_node_names
;
for
(
const
TensorId
&
id
:
inputs_
)
input_node_names
.
insert
(
id
.
first
);
for
(
const
TensorId
&
id
:
outputs_
)
output_node_names
.
insert
(
id
.
first
);
// Set of nodes that have already been visited by the DFS.
std
::
set
<
string
>
visited
;
// DFS backwards from output nodes to input nodes and collect operations.
std
::
vector
<
string
>
stack
(
output_node_names
.
begin
(),
output_node_names
.
end
());
while
(
!
stack
.
empty
())
{
const
string
name
=
stack
.
back
();
stack
.
pop_back
();
if
(
!
visited
.
insert
(
name
).
second
)
continue
;
// already visited; skip
const
tensorflow
::
NodeDef
*
node
=
nullptr
;
TF_RETURN_IF_ERROR
(
trained_model_
->
LookupNode
(
name
,
&
node
));
Operation
&
operation
=
operations_
[
name
];
if
(
operation
.
node
!=
nullptr
&&
operation
.
node
!=
node
)
{
return
tensorflow
::
errors
::
Internal
(
"Inconsistent nodes for operation "
,
name
,
" ("
,
operation
.
node
->
name
(),
" vs "
,
node
->
name
());
}
operation
.
node
=
node
;
// Function inputs bound the search; don't expand them.
if
(
input_node_names
.
find
(
name
)
!=
input_node_names
.
end
())
continue
;
// Expand (non-control) inputs.
for
(
const
string
&
input_name
:
node
->
input
())
{
if
(
IsControlDependency
(
input_name
))
continue
;
VLOG
(
1
)
<<
name
<<
" has input "
<<
input_name
;
TensorId
tensor_id
;
TF_RETURN_IF_ERROR
(
ParseTensorId
(
input_name
,
&
tensor_id
));
stack
.
push_back
(
tensor_id
.
first
);
}
}
return
tensorflow
::
Status
::
OK
();
}
tensorflow
::
Status
XlaCellConverter
::
ParseTensorId
(
const
string
&
tensor_name
,
TensorId
*
tensor_id
)
{
return
ParseTensorName
(
tensor_name
,
&
tensor_id
->
first
,
&
tensor_id
->
second
);
}
string
XlaCellConverter
::
AsVariableName
(
const
TensorId
&
tensor_id
)
{
if
(
tensor_id
.
second
==
0
)
return
tensor_id
.
first
;
return
tensorflow
::
strings
::
StrCat
(
tensor_id
.
first
,
":"
,
tensor_id
.
second
);
}
}
// namespace runtime
}
// namespace dragnn
}
// namespace syntaxnet
research/syntaxnet/dragnn/runtime/xla/xla_cell_converter.h
deleted
100644 → 0
View file @
a4bb31d0
// Copyright 2017 Google Inc. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
#ifndef DRAGNN_RUNTIME_XLA_XLA_CELL_CONVERTER_H_
#define DRAGNN_RUNTIME_XLA_XLA_CELL_CONVERTER_H_
#include <map>
#include <set>
#include <string>
#include <utility>
#include "dragnn/protos/export.pb.h"
#include "dragnn/runtime/trained_model.h"
#include "syntaxnet/base.h"
#include "tensorflow/core/framework/graph.pb.h"
#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/platform/types.h"
namespace
syntaxnet
{
namespace
dragnn
{
namespace
runtime
{
// Converter that extracts the cell computation from a DRAGNN component and
// writes it as a frozen TF GraphDef.
//
// The trained model that contains the DRAGNN component must also contain a
// CellSubgraphSpec proto embedded into the TF graph as a specifically-named
// constant node (see runtime_support.py). The CellSubgraphSpec defines the
// boundaries of the cell comptation.
//
// Each frozen GraphDef contains a single function that runs the cell and
// is named after the component. The function inputs are reference
// variables, so they can be pointed at externally-managed pieces of memory,
// provided sufficient size and alignment. Output storage is managed by XLA.
// The function inputs and outputs are marked with special names, namely:
// INPUT__<CellSubgraphSpec.Input.name>
// OUTPUT__<CellSubgraphSpec.Output.name>
class
XlaCellConverter
{
public:
// Extracts the cell of the DRAGNN component named |component_name| from the
// |trained_model| and overwrites the |graph| with an equivalent
// TF GraphDef in |graph| which is frozen (it encapsulates Variables). The
// CellSubgraphSpec stored in the graph is copied into |spec|. On error,
// returns non-OK.
static
tensorflow
::
Status
Convert
(
const
string
&
component_name
,
const
TrainedModel
&
trained_model
,
tensorflow
::
GraphDef
*
graph
,
CellSubgraphSpec
*
spec
);
private:
// A (node_name, output_index) pair denoting a tensor.
using
TensorId
=
std
::
pair
<
string
,
uint32
>
;
// A TF operation that makes up the cell.
struct
Operation
{
// The TF graph node represented by this operation.
const
tensorflow
::
NodeDef
*
node
=
nullptr
;
};
// Creates an empty converter.
XlaCellConverter
()
=
default
;
// Populates |dest_node| with the contents of |src_node|. For most nodes
// this is a complete copy. The exception is for nodes converted to Identity
// ops (e.g. Enter nodes). In this case, the op is changed to "Identity" and
// only critical attributes (for tensor type and shape) are retained.
tensorflow
::
Status
FillNode
(
const
tensorflow
::
NodeDef
&
src_node
,
tensorflow
::
NodeDef
*
dest_node
)
const
;
// Populates |dest_node| with the frozen contents of |src_node| which
// evaluates to a CellSubgraphSpec. The serialized contents will be
// stored in the value.tensor.string_val which makes extraction and
// development cleaner.
tensorflow
::
Status
FreezeSpecNode
(
const
tensorflow
::
NodeDef
&
src_node
,
tensorflow
::
NodeDef
*
dest_node
)
const
;
// Populates |dest_node| with the frozen contents of |src_node|. The
// output tensor for |src_node| will be evaluated and included as a
// constant in |dest_node|. On error, returns non-OK.
tensorflow
::
Status
FreezeNode
(
const
tensorflow
::
NodeDef
&
src_node
,
tensorflow
::
NodeDef
*
dest_node
)
const
;
// Copies over node attributes from |src_node| to |dest_node|, stripping out
// those which don't apply generally when |restrict_attributes| is true.
static
void
FillNodeAttributes
(
bool
restrict_attributes
,
const
tensorflow
::
NodeDef
&
src_node
,
tensorflow
::
NodeDef
*
dest_node
);
// Returns true if a node called |node_name| is in the subgraph required
// for evaluating the cell.
bool
IsNodeInSubgraph
(
const
string
&
node_name
)
const
;
// Implements the static Convert() method.
tensorflow
::
Status
ConvertImpl
(
const
string
&
component_name
,
const
TrainedModel
&
trained_model
,
tensorflow
::
GraphDef
*
graph
,
CellSubgraphSpec
*
spec
);
// Populates the |inputs_| and |outputs_| based on the |spec|. On error,
// returns non-OK.
tensorflow
::
Status
BuildInputsAndOutputs
(
const
CellSubgraphSpec
&
spec
);
// Walks from the |outputs_| to the |inputs_| in the |trained_model_|, adding
// to |operations_| along the way. Requires that BuildInputsAndOutputs() was
// called. On error, returns non-OK.
tensorflow
::
Status
BuildOperations
();
// Parses a |tensor_name| into a |tensor_id|. E.g.,
// "foo/bar:1" => ("foo/bar", 1)
// "baz" => ("baz", 0)
// On error, returns non-OK. It is an error if the |tensor_name| denotes a
// control dependency.
static
tensorflow
::
Status
ParseTensorId
(
const
string
&
tensor_name
,
TensorId
*
tensor_id
);
// Returns the canonically-formatted name of the graph variable associated
// with the |tensor_id|.
static
string
AsVariableName
(
const
TensorId
&
tensor_id
);
// Name of the component being converted.
string
component_name_
;
// Trained model that contains the DRAGNN model.
const
TrainedModel
*
trained_model_
=
nullptr
;
// Tensor ids that serve as inputs and outputs.
std
::
set
<
TensorId
>
inputs_
;
std
::
set
<
TensorId
>
outputs_
;
// Mapping from node name to Operation.
std
::
map
<
string
,
Operation
>
operations_
;
};
}
// namespace runtime
}
// namespace dragnn
}
// namespace syntaxnet
#endif // DRAGNN_RUNTIME_XLA_XLA_CELL_CONVERTER_H_
research/syntaxnet/dragnn/runtime/xla/xla_cell_converter_test.cc
deleted
100644 → 0
View file @
a4bb31d0
// Copyright 2017 Google Inc. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
#include "dragnn/runtime/xla/xla_cell_converter.h"
#include <string.h>
#include <iostream>
#include <string>
#include <utility>
#include <vector>
#include "dragnn/core/test/generic.h"
#include "dragnn/protos/export.pb.h"
#include "dragnn/runtime/alignment.h"
#include "dragnn/runtime/trained_model.h"
#include "dragnn/runtime/xla/xla_graph_utils.h"
#include "dragnn/runtime/xla/xla_spec_utils.h"
#include "syntaxnet/base.h"
#include "tensorflow/compiler/tf2xla/tf2xla.pb.h"
#include "tensorflow/compiler/tf2xla/xla_jit_compiled_cpu_function.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/core/framework/graph.pb.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/lib/core/stringpiece.h"
#include "tensorflow/core/lib/io/path.h"
#include "tensorflow/core/lib/strings/numbers.h"
#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/test.h"
namespace
syntaxnet
{
namespace
dragnn
{
namespace
runtime
{
namespace
{
// Relative path to a saved model.
constexpr
char
kSavedModelDir
[]
=
"dragnn/runtime/testdata/rnn_tagger"
;
// Names of components in the saved model.
const
char
*
kComponentNames
[]
=
{
"rnn"
,
"tagger"
};
// Returns a valid saved model directory.
string
GetSavedModelDir
()
{
return
tensorflow
::
io
::
JoinPath
(
test
::
GetTestDataPrefix
(),
kSavedModelDir
);
}
// Loads a trained model, converts each component to a frozen graph,
// compiles, and then runs the cell.
TEST
(
XlaCellConverterTest
,
LoadAndConvertAndRun
)
{
TrainedModel
trained_model
;
TF_ASSERT_OK
(
trained_model
.
Reset
(
GetSavedModelDir
()));
for
(
const
string
component_name
:
kComponentNames
)
{
LOG
(
INFO
)
<<
"Component: "
<<
component_name
;
// Freezes the graph.
tensorflow
::
GraphDef
graph_def
;
CellSubgraphSpec
spec_from_convert
;
TF_ASSERT_OK
(
XlaCellConverter
::
Convert
(
component_name
,
trained_model
,
&
graph_def
,
&
spec_from_convert
));
LOG
(
INFO
)
<<
component_name
<<
" graph nodes = "
<<
graph_def
.
node_size
();
// Extracts the CellSubgraphSpec and Config, then compiles.
CellSubgraphSpec
cell_subgraph_spec
;
tensorflow
::
tf2xla
::
Config
xla_config
;
TF_ASSERT_OK
(
GetSpecAndMakeXlaConfig
(
graph_def
,
&
cell_subgraph_spec
,
&
xla_config
));
EXPECT_THAT
(
cell_subgraph_spec
,
test
::
EqualsProto
(
spec_from_convert
));
LOG
(
INFO
)
<<
component_name
<<
" CellSubgraphSpec = "
<<
cell_subgraph_spec
.
DebugString
();
LOG
(
INFO
)
<<
component_name
<<
" Config = "
<<
xla_config
.
DebugString
();
TF_ASSERT_OK_AND_ASSIGN
(
std
::
unique_ptr
<
tensorflow
::
XlaJitCompiledCpuFunction
>
jit
,
tensorflow
::
XlaJitCompiledCpuFunction
::
Compile
(
graph_def
,
xla_config
,
xla
::
ExecutableBuildOptions
()));
// Creates an instance which also allocates inputs.
tensorflow
::
XlaCompiledCpuFunction
instance
(
jit
->
StaticData
());
// Zeros out the inputs.
const
auto
*
program_shape
=
instance
.
ProgramShape
();
ASSERT_NE
(
nullptr
,
program_shape
);
for
(
int
i
=
0
;
i
<
program_shape
->
parameters_size
();
i
++
)
{
const
auto
&
shape
=
program_shape
->
parameters
(
i
);
if
(
shape
.
element_type
()
!=
xla
::
OPAQUE
)
{
std
::
memset
(
instance
.
arg_data
(
i
),
0
,
xla
::
ShapeUtil
::
ByteSizeOf
(
shape
));
}
}
// This is just a "don't crash" test. XLA behavior will be exercised
// more thoroughly in regression tests.
LOG
(
INFO
)
<<
"Running "
<<
component_name
;
ASSERT_TRUE
(
instance
.
Run
());
}
}
}
// namespace
}
// namespace runtime
}
// namespace dragnn
}
// namespace syntaxnet
Prev
1
…
10
11
12
13
14
15
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment