Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
ModelZoo
ResNet50_tensorflow
Commits
a4bb31d0
Commit
a4bb31d0
authored
May 02, 2018
by
Terry Koo
Browse files
Export @195097388.
parent
dea7ecf6
Changes
296
Show whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
3650 additions
and
0 deletions
+3650
-0
research/syntaxnet/dragnn/runtime/myelin/myelin_library.cc
research/syntaxnet/dragnn/runtime/myelin/myelin_library.cc
+70
-0
research/syntaxnet/dragnn/runtime/myelin/myelin_library.h
research/syntaxnet/dragnn/runtime/myelin/myelin_library.h
+49
-0
research/syntaxnet/dragnn/runtime/myelin/myelin_library_test.cc
...ch/syntaxnet/dragnn/runtime/myelin/myelin_library_test.cc
+83
-0
research/syntaxnet/dragnn/runtime/myelin/myelin_spec_utils.cc
...arch/syntaxnet/dragnn/runtime/myelin/myelin_spec_utils.cc
+186
-0
research/syntaxnet/dragnn/runtime/myelin/myelin_spec_utils.h
research/syntaxnet/dragnn/runtime/myelin/myelin_spec_utils.h
+93
-0
research/syntaxnet/dragnn/runtime/myelin/myelin_spec_utils_test.cc
...syntaxnet/dragnn/runtime/myelin/myelin_spec_utils_test.cc
+307
-0
research/syntaxnet/dragnn/runtime/myelin/myelin_tracing.cc
research/syntaxnet/dragnn/runtime/myelin/myelin_tracing.cc
+131
-0
research/syntaxnet/dragnn/runtime/myelin/myelin_tracing.h
research/syntaxnet/dragnn/runtime/myelin/myelin_tracing.h
+36
-0
research/syntaxnet/dragnn/runtime/myelin/myelin_tracing_test.cc
...ch/syntaxnet/dragnn/runtime/myelin/myelin_tracing_test.cc
+345
-0
research/syntaxnet/dragnn/runtime/myelin/myelination.cc
research/syntaxnet/dragnn/runtime/myelin/myelination.cc
+147
-0
research/syntaxnet/dragnn/runtime/myelin/myelination.h
research/syntaxnet/dragnn/runtime/myelin/myelination.h
+72
-0
research/syntaxnet/dragnn/runtime/myelin/myelination_test.cc
research/syntaxnet/dragnn/runtime/myelin/myelination_test.cc
+225
-0
research/syntaxnet/dragnn/runtime/myelin/sequence_myelin_dynamic_component.cc
...ragnn/runtime/myelin/sequence_myelin_dynamic_component.cc
+166
-0
research/syntaxnet/dragnn/runtime/myelin/sequence_myelin_dynamic_component_test.cc
.../runtime/myelin/sequence_myelin_dynamic_component_test.cc
+453
-0
research/syntaxnet/dragnn/runtime/myelin/testdata/myelination_output/master-spec
...nn/runtime/myelin/testdata/myelination_output/master-spec
+160
-0
research/syntaxnet/dragnn/runtime/myelin/testdata/myelination_output/rnn.flow
...ragnn/runtime/myelin/testdata/myelination_output/rnn.flow
+0
-0
research/syntaxnet/dragnn/runtime/myelin/testdata/myelination_output/tagger.flow
...nn/runtime/myelin/testdata/myelination_output/tagger.flow
+0
-0
research/syntaxnet/dragnn/runtime/network_states.cc
research/syntaxnet/dragnn/runtime/network_states.cc
+197
-0
research/syntaxnet/dragnn/runtime/network_states.h
research/syntaxnet/dragnn/runtime/network_states.h
+422
-0
research/syntaxnet/dragnn/runtime/network_states_test.cc
research/syntaxnet/dragnn/runtime/network_states_test.cc
+508
-0
No files found.
Too many changes to show.
To preserve performance only
296 of 296+
files are displayed.
Plain diff
Email patch
research/syntaxnet/dragnn/runtime/myelin/myelin_library.cc
0 → 100644
View file @
a4bb31d0
// Copyright 2017 Google Inc. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
#include "dragnn/runtime/myelin/myelin_library.h"
#include <string>
#include "syntaxnet/base.h"
#include "tensorflow/core/lib/strings/strcat.h"
namespace
syntaxnet
{
namespace
dragnn
{
namespace
runtime
{
bool
PreMultipliedEmbeddings
::
Transform
(
sling
::
myelin
::
Flow
*
flow
)
{
bool
transformed_something
=
false
;
for
(
sling
::
myelin
::
Flow
::
Operation
*
matmul
:
flow
->
Find
({
"Gather"
,
"MatMul"
}))
{
if
(
matmul
->
indegree
()
!=
2
)
continue
;
sling
::
myelin
::
Flow
::
Variable
*
gathered
=
matmul
->
inputs
[
0
];
sling
::
myelin
::
Flow
::
Variable
*
weights
=
matmul
->
inputs
[
1
];
sling
::
myelin
::
Flow
::
Operation
*
gather
=
gathered
->
producer
;
if
(
gather
->
indegree
()
!=
2
)
continue
;
sling
::
myelin
::
Flow
::
Variable
*
embeddings
=
gather
->
inputs
[
0
];
sling
::
myelin
::
Flow
::
Variable
*
indices
=
gather
->
inputs
[
1
];
if
(
gathered
->
out
)
continue
;
if
(
!
weights
->
constant
())
continue
;
if
(
weights
->
rank
()
!=
2
)
continue
;
if
(
!
embeddings
->
constant
())
continue
;
if
(
embeddings
->
rank
()
!=
2
)
continue
;
if
(
embeddings
->
type
!=
weights
->
type
)
continue
;
// Add an operation to pre-multiply the embeddings and weights.
const
string
product_name
=
tensorflow
::
strings
::
StrCat
(
embeddings
->
name
,
"/"
,
weights
->
name
);
const
string
pre_multiply_name
=
tensorflow
::
strings
::
StrCat
(
product_name
,
"/PreMultiply"
);
sling
::
myelin
::
Flow
::
Variable
*
product
=
flow
->
AddVariable
(
product_name
,
weights
->
type
,
{
embeddings
->
dim
(
0
),
weights
->
dim
(
1
)});
flow
->
AddOperation
(
gather
->
func
,
pre_multiply_name
,
"MatMul"
,
{
embeddings
,
weights
},
{
product
});
// Convert the MatMul into a Gather on the pre-multiplied embeddings.
matmul
->
type
=
"Gather"
;
matmul
->
ReplaceInput
(
gathered
,
product
);
matmul
->
ReplaceInput
(
weights
,
indices
);
// Remove the original Gather if it is no longer used.
if
(
gathered
->
consumers
.
empty
())
flow
->
RemoveOperation
(
gather
);
transformed_something
=
true
;
}
return
transformed_something
;
}
}
// namespace runtime
}
// namespace dragnn
}
// namespace syntaxnet
research/syntaxnet/dragnn/runtime/myelin/myelin_library.h
0 → 100644
View file @
a4bb31d0
// Copyright 2017 Google Inc. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
// Myelin typers, transformers, and kernels specific to the DRAGNN runtime.
#ifndef DRAGNN_RUNTIME_MYELIN_MYELIN_LIBRARY_H_
#define DRAGNN_RUNTIME_MYELIN_MYELIN_LIBRARY_H_
#include "sling/myelin/flow.h"
namespace
syntaxnet
{
namespace
dragnn
{
namespace
runtime
{
// Rearranges the flow to allow the "pre-multiplied embeddings" optimization.
// Specifically, performs the following transformation:
//
// tf.matmul(tf.gather(embeddings, indices), weights) =
// tf.gather(tf.matmul(embeddings, weights), indices)
//
// The transformation only applies if the embeddings and weights are constants.
// Myelin has constant folding transformations that will trigger and pre-compute
// the multiplication of the embeddings and weights.
//
// NB: There is already a PrecomputedEmbeddings transformer in Myelin but that
// operates on the Lookup op and expects an intervening Reshape.
class
PreMultipliedEmbeddings
:
public
sling
::
myelin
::
Transformer
{
public:
// Implements Transformer.
bool
Transform
(
sling
::
myelin
::
Flow
*
flow
)
override
;
};
}
// namespace runtime
}
// namespace dragnn
}
// namespace syntaxnet
#endif // DRAGNN_RUNTIME_MYELIN_MYELIN_LIBRARY_H_
research/syntaxnet/dragnn/runtime/myelin/myelin_library_test.cc
0 → 100644
View file @
a4bb31d0
// Copyright 2017 Google Inc. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
#include "dragnn/runtime/myelin/myelin_library.h"
#include <vector>
#include "tensorflow/core/platform/test.h"
namespace
syntaxnet
{
namespace
dragnn
{
namespace
runtime
{
namespace
{
// Tests that PreMultipliedEmbeddings does nothing on an empty Flow.
TEST
(
PreMultipliedEmbeddingsTest
,
DoesNothingOnEmptyFlow
)
{
sling
::
myelin
::
Flow
flow
;
PreMultipliedEmbeddings
transformer
;
EXPECT_FALSE
(
transformer
.
Transform
(
&
flow
));
}
// Tests that PreMultipliedEmbeddings can rearrange a MatMul of a Gather into a
// Gather of a pre-multiplied matrix.
TEST
(
PreMultipliedEmbeddingsTest
,
AppliesPreMultiplication
)
{
sling
::
myelin
::
Flow
flow
;
sling
::
myelin
::
Flow
::
Function
*
function
=
flow
.
AddFunction
(
"test_function"
);
sling
::
myelin
::
Flow
::
Variable
*
indices
=
flow
.
AddVariable
(
"indices"
,
sling
::
myelin
::
DT_INT32
,
{
1
});
sling
::
myelin
::
Flow
::
Variable
*
embeddings
=
flow
.
AddVariable
(
"embeddings"
,
sling
::
myelin
::
DT_FLOAT
,
{
10
,
20
});
sling
::
myelin
::
Flow
::
Variable
*
gathered
=
flow
.
AddVariable
(
"gathered"
,
sling
::
myelin
::
DT_FLOAT
,
{
1
,
20
});
sling
::
myelin
::
Flow
::
Variable
*
weights
=
flow
.
AddVariable
(
"weights"
,
sling
::
myelin
::
DT_FLOAT
,
{
20
,
30
});
sling
::
myelin
::
Flow
::
Variable
*
output
=
flow
.
AddVariable
(
"output"
,
sling
::
myelin
::
DT_FLOAT
,
{
1
,
30
});
flow
.
AddOperation
(
function
,
"gather"
,
"Gather"
,
{
embeddings
,
indices
},
{
gathered
});
flow
.
AddOperation
(
function
,
"matmul"
,
"MatMul"
,
{
gathered
,
weights
},
{
output
});
// Attach constant data to the matrices.
const
std
::
vector
<
float
>
floats
(
20
*
30
);
// big enough for both
embeddings
->
SetData
(
floats
.
data
(),
10
*
20
*
sizeof
(
float
));
weights
->
SetData
(
floats
.
data
(),
20
*
30
*
sizeof
(
float
));
PreMultipliedEmbeddings
transformer
;
ASSERT_TRUE
(
transformer
.
Transform
(
&
flow
));
sling
::
myelin
::
Flow
::
Variable
*
product
=
flow
.
Var
(
"embeddings/weights"
);
ASSERT_NE
(
product
,
nullptr
);
ASSERT_EQ
(
product
->
rank
(),
2
);
EXPECT_EQ
(
product
->
dim
(
0
),
10
);
EXPECT_EQ
(
product
->
dim
(
1
),
30
);
sling
::
myelin
::
Flow
::
Operation
*
pre_multiply
=
flow
.
Op
(
"embeddings/weights/PreMultiply"
);
ASSERT_NE
(
pre_multiply
,
nullptr
);
ASSERT_EQ
(
pre_multiply
->
indegree
(),
2
);
ASSERT_EQ
(
pre_multiply
->
outdegree
(),
1
);
EXPECT_EQ
(
pre_multiply
->
type
,
"MatMul"
);
EXPECT_EQ
(
pre_multiply
->
inputs
[
0
],
embeddings
);
EXPECT_EQ
(
pre_multiply
->
inputs
[
1
],
weights
);
EXPECT_EQ
(
pre_multiply
->
outputs
[
0
],
product
);
}
}
// namespace
}
// namespace runtime
}
// namespace dragnn
}
// namespace syntaxnet
research/syntaxnet/dragnn/runtime/myelin/myelin_spec_utils.cc
0 → 100644
View file @
a4bb31d0
// Copyright 2017 Google Inc. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
#include "dragnn/runtime/myelin/myelin_spec_utils.h"
#include <algorithm>
#include "dragnn/runtime/myelin/myelin_library.h"
#include "sling/base/status.h"
#include "sling/file/file.h"
#include "sling/myelin/kernel/tensorflow.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/stringpiece.h"
#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/lib/strings/strcat.h"
namespace
syntaxnet
{
namespace
dragnn
{
namespace
runtime
{
const
char
*
const
kMyelinFlowResourceName
=
"myelin-flow"
;
const
char
*
const
kMyelinFlowResourceFileFormat
=
"model"
;
const
char
*
const
kMyelinFlowResourceRecordFormat
=
"sling.myelin.Flow"
;
tensorflow
::
Status
LookupMyelinFlowResource
(
const
ComponentSpec
&
component_spec
,
const
Resource
**
flow_resource
)
{
const
Resource
*
found_resource
=
nullptr
;
for
(
const
Resource
&
resource
:
component_spec
.
resource
())
{
if
(
resource
.
name
()
!=
kMyelinFlowResourceName
)
continue
;
if
(
found_resource
!=
nullptr
)
{
return
tensorflow
::
errors
::
InvalidArgument
(
"Component '"
,
component_spec
.
name
(),
"' contains duplicate Myelin Flow resources"
);
}
if
(
resource
.
part_size
()
!=
1
)
{
return
tensorflow
::
errors
::
InvalidArgument
(
"Component '"
,
component_spec
.
name
(),
"' has malformed Myelin Flow resource; expected 1 part"
);
}
const
Part
&
part
=
resource
.
part
(
0
);
if
(
part
.
file_format
()
!=
kMyelinFlowResourceFileFormat
)
{
return
tensorflow
::
errors
::
InvalidArgument
(
"Component '"
,
component_spec
.
name
(),
"' has malformed Myelin Flow resource; wrong file format"
);
}
if
(
part
.
record_format
()
!=
kMyelinFlowResourceRecordFormat
)
{
return
tensorflow
::
errors
::
InvalidArgument
(
"Component '"
,
component_spec
.
name
(),
"' has malformed Myelin Flow resource; wrong record format"
);
}
found_resource
=
&
resource
;
}
if
(
found_resource
==
nullptr
)
{
return
tensorflow
::
errors
::
NotFound
(
"Component '"
,
component_spec
.
name
(),
"' has no Myelin Flow resource"
);
}
// Success; make modifications.
*
flow_resource
=
found_resource
;
return
tensorflow
::
Status
::
OK
();
}
tensorflow
::
Status
AddMyelinFlowResource
(
const
string
&
path
,
ComponentSpec
*
component_spec
)
{
if
(
std
::
any_of
(
component_spec
->
resource
().
begin
(),
component_spec
->
resource
().
end
(),
[](
const
Resource
&
resource
)
{
return
resource
.
name
()
==
kMyelinFlowResourceName
;
}))
{
return
tensorflow
::
errors
::
InvalidArgument
(
"Component '"
,
component_spec
->
name
(),
"' already contains a Myelin Flow resource"
);
}
// Success; make modifications.
Resource
*
resource
=
component_spec
->
add_resource
();
resource
->
set_name
(
kMyelinFlowResourceName
);
Part
*
part
=
resource
->
add_part
();
part
->
set_file_pattern
(
path
);
part
->
set_file_format
(
kMyelinFlowResourceFileFormat
);
part
->
set_record_format
(
kMyelinFlowResourceRecordFormat
);
return
tensorflow
::
Status
::
OK
();
}
tensorflow
::
Status
LoadMyelinFlow
(
const
string
&
flow_path
,
sling
::
myelin
::
Flow
*
flow
)
{
sling
::
File
::
Init
();
const
sling
::
Status
status
=
flow
->
Load
(
flow_path
);
if
(
!
status
.
ok
())
{
return
tensorflow
::
errors
::
Internal
(
"Failed to load Myelin Flow from '"
,
flow_path
,
": "
,
status
.
ToString
());
}
// Mark cell inputs and outputs.
for
(
sling
::
myelin
::
Flow
::
Variable
*
variable
:
flow
->
vars
())
{
for
(
tensorflow
::
StringPiece
alias
:
variable
->
aliases
)
{
if
(
tensorflow
::
str_util
::
StartsWith
(
alias
,
"INPUT/"
))
{
variable
->
in
=
true
;
}
if
(
tensorflow
::
str_util
::
StartsWith
(
alias
,
"OUTPUT/"
))
{
variable
->
out
=
true
;
}
}
}
return
tensorflow
::
Status
::
OK
();
}
void
RegisterMyelinLibraries
(
sling
::
myelin
::
Library
*
library
)
{
// TODO(googleuser): Add more libraries?
sling
::
myelin
::
RegisterTensorflowLibrary
(
library
);
library
->
RegisterTransformer
(
new
PreMultipliedEmbeddings
());
}
std
::
set
<
string
>
GetRecurrentLayerNames
(
const
sling
::
myelin
::
Flow
&
flow
)
{
std
::
set
<
string
>
names
;
for
(
const
sling
::
myelin
::
Flow
::
Variable
*
variable
:
flow
.
vars
())
{
for
(
tensorflow
::
StringPiece
alias
:
variable
->
aliases
)
{
if
(
!
tensorflow
::
str_util
::
ConsumePrefix
(
&
alias
,
"INPUT/"
))
continue
;
if
(
tensorflow
::
str_util
::
ConsumePrefix
(
&
alias
,
"fixed_channel_"
))
{
continue
;
}
if
(
tensorflow
::
str_util
::
ConsumePrefix
(
&
alias
,
"linked_channel_"
))
{
continue
;
}
names
.
insert
(
alias
.
ToString
());
}
}
return
names
;
}
std
::
set
<
string
>
GetOutputLayerNames
(
const
sling
::
myelin
::
Flow
&
flow
)
{
std
::
set
<
string
>
names
;
for
(
const
sling
::
myelin
::
Flow
::
Variable
*
variable
:
flow
.
vars
())
{
for
(
tensorflow
::
StringPiece
alias
:
variable
->
aliases
)
{
if
(
!
tensorflow
::
str_util
::
ConsumePrefix
(
&
alias
,
"OUTPUT/"
))
continue
;
names
.
insert
(
alias
.
ToString
());
}
}
return
names
;
}
string
MakeMyelinInputFixedFeatureIdName
(
int
channel_id
,
int
index
)
{
return
tensorflow
::
strings
::
StrCat
(
"INPUT/fixed_channel_"
,
channel_id
,
"_index_"
,
index
,
"_ids"
);
}
string
MakeMyelinInputLinkedActivationVectorName
(
int
channel_id
)
{
return
tensorflow
::
strings
::
StrCat
(
"INPUT/linked_channel_"
,
channel_id
,
"_activations"
);
}
string
MakeMyelinInputLinkedOutOfBoundsIndicatorName
(
int
channel_id
)
{
return
tensorflow
::
strings
::
StrCat
(
"INPUT/linked_channel_"
,
channel_id
,
"_out_of_bounds"
);
}
string
MakeMyelinInputRecurrentLayerName
(
const
string
&
layer_name
)
{
return
tensorflow
::
strings
::
StrCat
(
"INPUT/"
,
layer_name
);
}
string
MakeMyelinOutputLayerName
(
const
string
&
layer_name
)
{
return
tensorflow
::
strings
::
StrCat
(
"OUTPUT/"
,
layer_name
);
}
}
// namespace runtime
}
// namespace dragnn
}
// namespace syntaxnet
research/syntaxnet/dragnn/runtime/myelin/myelin_spec_utils.h
0 → 100644
View file @
a4bb31d0
// Copyright 2017 Google Inc. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
// Utils for working with specifications of Myelin-based DRAGNN runtime models.
#ifndef DRAGNN_RUNTIME_MYELIN_MYELIN_SPEC_UTILS_H_
#define DRAGNN_RUNTIME_MYELIN_MYELIN_SPEC_UTILS_H_
#include <set>
#include <string>
#include "dragnn/protos/spec.pb.h"
#include "syntaxnet/base.h"
#include "sling/myelin/compute.h"
#include "sling/myelin/flow.h"
#include "tensorflow/core/lib/core/status.h"
namespace
syntaxnet
{
namespace
dragnn
{
namespace
runtime
{
// The name, file format, and record format of the resource that contains the
// Myelin Flow for each component.
extern
const
char
*
const
kMyelinFlowResourceName
;
extern
const
char
*
const
kMyelinFlowResourceFileFormat
;
extern
const
char
*
const
kMyelinFlowResourceRecordFormat
;
// Points |flow_resource| to the resource in the |component_spec| that specifies
// the Myelin Flow file. On error, returns non-OK and modifies nothing.
tensorflow
::
Status
LookupMyelinFlowResource
(
const
ComponentSpec
&
component_spec
,
const
Resource
**
flow_resource
);
// Adds a resource to the |component_spec| that specifies the Myelin Flow file
// at the |path|. On error, returns non-OK and modifies nothing.
tensorflow
::
Status
AddMyelinFlowResource
(
const
string
&
path
,
ComponentSpec
*
component_spec
);
// Loads a Myelin Flow file from the |flow_path| into the |flow| and ensures
// that inputs and outputs are marked properly. On error, returns non-OK.
tensorflow
::
Status
LoadMyelinFlow
(
const
string
&
flow_path
,
sling
::
myelin
::
Flow
*
flow
);
// Registers a standard set of libraries in the Myelin |library|.
void
RegisterMyelinLibraries
(
sling
::
myelin
::
Library
*
library
);
// Returns the set of recurrent input layer names in the |flow|. A recurrent
// input layer is defined as any input that is not a fixed or linked feature.
//
// Note that recurrent input layers differ from recurrent linked features. The
// latter are linked features that have been configured to refer to the current
// component, while the former are hard-coded in the network structure itself.
// See, for example, the context tensor arrays that hold the cell state in the
// LstmNetwork.
//
// TODO(googleuser): Use a more robust naming scheme for recurrent inputs?
std
::
set
<
string
>
GetRecurrentLayerNames
(
const
sling
::
myelin
::
Flow
&
flow
);
// Returns the set of output layer names in the |flow|.
std
::
set
<
string
>
GetOutputLayerNames
(
const
sling
::
myelin
::
Flow
&
flow
);
// Returns the name of the Myelin input for the ID of the |index|'th feature in
// the |channel_id|'th fixed feature channel.
string
MakeMyelinInputFixedFeatureIdName
(
int
channel_id
,
int
index
);
// Returns the names of the Myelin inputs for the source activation vector and
// out-of-bounds indicator of the |channel_id|'th linked feature channel.
string
MakeMyelinInputLinkedActivationVectorName
(
int
channel_id
);
string
MakeMyelinInputLinkedOutOfBoundsIndicatorName
(
int
channel_id
);
// Returns the name of the Myelin input for the hard-coded recurrent layer named
// |layer_name|.
string
MakeMyelinInputRecurrentLayerName
(
const
string
&
layer_name
);
// Returns the name of the Myelin output for the layer named |layer_name|.
string
MakeMyelinOutputLayerName
(
const
string
&
layer_name
);
}
// namespace runtime
}
// namespace dragnn
}
// namespace syntaxnet
#endif // DRAGNN_RUNTIME_MYELIN_MYELIN_SPEC_UTILS_H_
research/syntaxnet/dragnn/runtime/myelin/myelin_spec_utils_test.cc
0 → 100644
View file @
a4bb31d0
// Copyright 2017 Google Inc. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
#include "dragnn/runtime/myelin/myelin_spec_utils.h"
#include <set>
#include <string>
#include "dragnn/core/test/generic.h"
#include "dragnn/protos/spec.pb.h"
#include "syntaxnet/base.h"
#include "sling/file/file.h"
#include "sling/myelin/compute.h"
#include "sling/myelin/flow.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/lib/io/path.h"
#include "tensorflow/core/platform/test.h"
namespace
syntaxnet
{
namespace
dragnn
{
namespace
runtime
{
namespace
{
TEST
(
MyelinSpecUtilsTest
,
AddAndLookupMyelinFlowResource
)
{
ComponentSpec
component_spec
;
TF_ASSERT_OK
(
AddMyelinFlowResource
(
"/dev/null"
,
&
component_spec
));
const
Resource
*
resource
=
nullptr
;
TF_ASSERT_OK
(
LookupMyelinFlowResource
(
component_spec
,
&
resource
));
ASSERT_NE
(
resource
,
nullptr
);
EXPECT_EQ
(
resource
->
name
(),
kMyelinFlowResourceName
);
ASSERT_EQ
(
resource
->
part_size
(),
1
);
EXPECT_EQ
(
resource
->
part
(
0
).
file_pattern
(),
"/dev/null"
);
EXPECT_EQ
(
resource
->
part
(
0
).
file_format
(),
kMyelinFlowResourceFileFormat
);
EXPECT_EQ
(
resource
->
part
(
0
).
record_format
(),
kMyelinFlowResourceRecordFormat
);
}
TEST
(
MyelinSpecUtilsTest
,
LookupMyelinFlowResourceMissing
)
{
ComponentSpec
component_spec
;
const
Resource
*
resource
=
nullptr
;
EXPECT_THAT
(
LookupMyelinFlowResource
(
component_spec
,
&
resource
),
test
::
IsErrorWithSubstr
(
"has no Myelin Flow resource"
));
component_spec
.
add_resource
()
->
set_name
(
"foo"
);
EXPECT_THAT
(
LookupMyelinFlowResource
(
component_spec
,
&
resource
),
test
::
IsErrorWithSubstr
(
"has no Myelin Flow resource"
));
component_spec
.
add_resource
()
->
set_name
(
"bar"
);
EXPECT_THAT
(
LookupMyelinFlowResource
(
component_spec
,
&
resource
),
test
::
IsErrorWithSubstr
(
"has no Myelin Flow resource"
));
}
TEST
(
MyelinSpecUtilsTest
,
LookupMyelinFlowResourceWrongName
)
{
ComponentSpec
component_spec
;
TF_ASSERT_OK
(
AddMyelinFlowResource
(
"/dev/null"
,
&
component_spec
));
component_spec
.
mutable_resource
(
0
)
->
set_name
(
"bad"
);
const
Resource
*
resource
=
nullptr
;
EXPECT_THAT
(
LookupMyelinFlowResource
(
component_spec
,
&
resource
),
test
::
IsErrorWithSubstr
(
"has no Myelin Flow resource"
));
}
TEST
(
MyelinSpecUtilsTest
,
LookupMyelinFlowResourceWrongFileFormat
)
{
ComponentSpec
component_spec
;
TF_ASSERT_OK
(
AddMyelinFlowResource
(
"/dev/null"
,
&
component_spec
));
component_spec
.
mutable_resource
(
0
)
->
mutable_part
(
0
)
->
set_file_format
(
"bad"
);
const
Resource
*
resource
=
nullptr
;
EXPECT_THAT
(
LookupMyelinFlowResource
(
component_spec
,
&
resource
),
test
::
IsErrorWithSubstr
(
"wrong file format"
));
}
TEST
(
MyelinSpecUtilsTest
,
LookupMyelinFlowResourceWrongRecordFormat
)
{
ComponentSpec
component_spec
;
TF_ASSERT_OK
(
AddMyelinFlowResource
(
"/dev/null"
,
&
component_spec
));
component_spec
.
mutable_resource
(
0
)
->
mutable_part
(
0
)
->
set_record_format
(
"bad"
);
const
Resource
*
resource
=
nullptr
;
EXPECT_THAT
(
LookupMyelinFlowResource
(
component_spec
,
&
resource
),
test
::
IsErrorWithSubstr
(
"wrong record format"
));
}
TEST
(
MyelinSpecUtilsTest
,
LookupMyelinFlowResourceWrongNumberOfParts
)
{
ComponentSpec
component_spec
;
TF_ASSERT_OK
(
AddMyelinFlowResource
(
"/dev/null"
,
&
component_spec
));
component_spec
.
mutable_resource
(
0
)
->
add_part
();
const
Resource
*
resource
=
nullptr
;
EXPECT_THAT
(
LookupMyelinFlowResource
(
component_spec
,
&
resource
),
test
::
IsErrorWithSubstr
(
"expected 1 part"
));
}
TEST
(
MyelinSpecUtilsTest
,
LookupMyelinFlowResourceDuplicate
)
{
ComponentSpec
component_spec
;
TF_ASSERT_OK
(
AddMyelinFlowResource
(
"/dev/null"
,
&
component_spec
));
component_spec
.
add_resource
()
->
set_name
(
kMyelinFlowResourceName
);
const
Resource
*
resource
=
nullptr
;
EXPECT_THAT
(
LookupMyelinFlowResource
(
component_spec
,
&
resource
),
test
::
IsErrorWithSubstr
(
"contains duplicate Myelin Flow resource"
));
}
TEST
(
MyelinSpecUtilsTest
,
AddMyelinFlowResourceDuplicate
)
{
ComponentSpec
component_spec
;
TF_ASSERT_OK
(
AddMyelinFlowResource
(
"/dev/null"
,
&
component_spec
));
EXPECT_THAT
(
AddMyelinFlowResource
(
"another/flow"
,
&
component_spec
),
test
::
IsErrorWithSubstr
(
"already contains a Myelin Flow resource"
));
}
TEST
(
MyelinSpecUtilsTest
,
LoadMyelinFlowInvalidPath
)
{
sling
::
myelin
::
Flow
flow
;
EXPECT_THAT
(
LoadMyelinFlow
(
"invalid/path"
,
&
flow
),
test
::
IsErrorWithSubstr
(
"Failed to load Myelin Flow"
));
}
TEST
(
MyelinSpecUtilsTest
,
LoadMyelinFlowValidFile
)
{
// Build and write a Flow file with some variables that are annotated with
// input and output aliases.
sling
::
myelin
::
Flow
original_flow
;
original_flow
.
AddVariable
(
"input"
,
sling
::
myelin
::
DT_FLOAT
,
sling
::
myelin
::
Shape
())
->
aliases
=
{
"INPUT/a"
};
original_flow
.
AddVariable
(
"output"
,
sling
::
myelin
::
DT_FLOAT
,
sling
::
myelin
::
Shape
())
->
aliases
=
{
"OUTPUT/b"
};
original_flow
.
AddVariable
(
"both"
,
sling
::
myelin
::
DT_FLOAT
,
sling
::
myelin
::
Shape
())
->
aliases
=
{
"INPUT/c"
,
"OUTPUT/d"
};
original_flow
.
AddVariable
(
"neither"
,
sling
::
myelin
::
DT_FLOAT
,
sling
::
myelin
::
Shape
());
const
string
flow_path
=
tensorflow
::
io
::
JoinPath
(
tensorflow
::
testing
::
TmpDir
(),
"foo.flow"
);
sling
::
File
::
Init
();
original_flow
.
Save
(
flow_path
);
// Load the Flow file into a fresh Flow and check that inputs and outputs are
// marked as such.
sling
::
myelin
::
Flow
flow
;
TF_ASSERT_OK
(
LoadMyelinFlow
(
flow_path
,
&
flow
));
ASSERT_NE
(
flow
.
Var
(
"input"
),
nullptr
);
EXPECT_TRUE
(
flow
.
Var
(
"input"
)
->
in
);
EXPECT_FALSE
(
flow
.
Var
(
"input"
)
->
out
);
ASSERT_NE
(
flow
.
Var
(
"output"
),
nullptr
);
EXPECT_FALSE
(
flow
.
Var
(
"output"
)
->
in
);
EXPECT_TRUE
(
flow
.
Var
(
"output"
)
->
out
);
ASSERT_NE
(
flow
.
Var
(
"both"
),
nullptr
);
EXPECT_TRUE
(
flow
.
Var
(
"both"
)
->
in
);
EXPECT_TRUE
(
flow
.
Var
(
"both"
)
->
out
);
ASSERT_NE
(
flow
.
Var
(
"neither"
),
nullptr
);
EXPECT_FALSE
(
flow
.
Var
(
"neither"
)
->
in
);
EXPECT_FALSE
(
flow
.
Var
(
"neither"
)
->
out
);
}
TEST
(
MyelinSpecUtilsTest
,
RegisterMyelinLibraries
)
{
sling
::
myelin
::
Library
library
;
RegisterMyelinLibraries
(
&
library
);
// The |library| should contain something.
EXPECT_GT
(
library
.
transformers
().
size
()
+
library
.
typers
().
size
(),
0
);
}
TEST
(
MyelinSpecUtilsTest
,
GetRecurrentLayerNamesEmpty
)
{
sling
::
myelin
::
Flow
flow
;
const
std
::
set
<
string
>
expected_names
;
EXPECT_EQ
(
GetRecurrentLayerNames
(
flow
),
expected_names
);
}
TEST
(
MyelinSpecUtilsTest
,
GetRecurrentLayerNamesVariablesWithNoAliases
)
{
sling
::
myelin
::
Flow
flow
;
flow
.
AddVariable
(
"x"
,
sling
::
myelin
::
DT_FLOAT
,
{});
flow
.
AddVariable
(
"y"
,
sling
::
myelin
::
DT_INT32
,
{});
const
std
::
set
<
string
>
expected_names
;
EXPECT_EQ
(
GetRecurrentLayerNames
(
flow
),
expected_names
);
}
TEST
(
MyelinSpecUtilsTest
,
GetRecurrentLayerNamesVariablesWithAliases
)
{
sling
::
myelin
::
Flow
flow
;
flow
.
AddVariable
(
"x"
,
sling
::
myelin
::
DT_FLOAT
,
{})
->
aliases
=
{
"foo"
,
"bar"
};
flow
.
AddVariable
(
"y"
,
sling
::
myelin
::
DT_INT32
,
{})
->
aliases
=
{
"INPUT/y"
,
//
"INPUT/fixed_channel_0_index_0_ids"
,
//
"INPUT/linked_channel_0_activations"
};
flow
.
AddVariable
(
"z"
,
sling
::
myelin
::
DT_INT32
,
{})
->
aliases
=
{
"OUTPUT/z"
};
const
std
::
set
<
string
>
expected_names
=
{
"y"
};
EXPECT_EQ
(
GetRecurrentLayerNames
(
flow
),
expected_names
);
}
TEST
(
MyelinSpecUtilsTest
,
GetRecurrentLayerNamesVariablesWithMultipleAliases
)
{
sling
::
myelin
::
Flow
flow
;
flow
.
AddVariable
(
"x"
,
sling
::
myelin
::
DT_FLOAT
,
{})
->
aliases
=
{
"foo"
,
"bar"
};
flow
.
AddVariable
(
"y"
,
sling
::
myelin
::
DT_INT32
,
{})
->
aliases
=
{
"INPUT/recurrent_1"
,
//
"INPUT/recurrent_2"
,
//
"INPUT/fixed_channel_0_index_0_ids"
,
//
"INPUT/linked_channel_0_activations"
};
flow
.
AddVariable
(
"z"
,
sling
::
myelin
::
DT_INT32
,
{})
->
aliases
=
{
"OUTPUT/output_1"
,
//
"OUTPUT/output_2"
};
const
std
::
set
<
string
>
expected_names
=
{
"recurrent_1"
,
"recurrent_2"
};
EXPECT_EQ
(
GetRecurrentLayerNames
(
flow
),
expected_names
);
}
TEST
(
MyelinSpecUtilsTest
,
GetOutputLayerNamesEmpty
)
{
sling
::
myelin
::
Flow
flow
;
const
std
::
set
<
string
>
expected_names
;
EXPECT_EQ
(
GetOutputLayerNames
(
flow
),
expected_names
);
}
TEST
(
MyelinSpecUtilsTest
,
GetOutputLayerNamesVariablesWithNoAliases
)
{
sling
::
myelin
::
Flow
flow
;
flow
.
AddVariable
(
"x"
,
sling
::
myelin
::
DT_FLOAT
,
{});
flow
.
AddVariable
(
"y"
,
sling
::
myelin
::
DT_INT32
,
{});
const
std
::
set
<
string
>
expected_names
;
EXPECT_EQ
(
GetOutputLayerNames
(
flow
),
expected_names
);
}
TEST
(
MyelinSpecUtilsTest
,
GetOutputLayerNamesVariablesWithAliases
)
{
sling
::
myelin
::
Flow
flow
;
flow
.
AddVariable
(
"x"
,
sling
::
myelin
::
DT_FLOAT
,
{})
->
aliases
=
{
"foo"
,
"bar"
};
flow
.
AddVariable
(
"y"
,
sling
::
myelin
::
DT_INT32
,
{})
->
aliases
=
{
"INPUT/y"
,
//
"INPUT/fixed_channel_0_index_0_ids"
,
//
"INPUT/linked_channel_0_activations"
};
flow
.
AddVariable
(
"z"
,
sling
::
myelin
::
DT_INT32
,
{})
->
aliases
=
{
"OUTPUT/z"
};
const
std
::
set
<
string
>
expected_names
=
{
"z"
};
EXPECT_EQ
(
GetOutputLayerNames
(
flow
),
expected_names
);
}
TEST
(
MyelinSpecUtilsTest
,
GetOutputLayerNamesVariablesWithMultipleAliases
)
{
sling
::
myelin
::
Flow
flow
;
flow
.
AddVariable
(
"x"
,
sling
::
myelin
::
DT_FLOAT
,
{})
->
aliases
=
{
"foo"
,
"bar"
};
flow
.
AddVariable
(
"y"
,
sling
::
myelin
::
DT_INT32
,
{})
->
aliases
=
{
"INPUT/recurrent_1"
,
//
"INPUT/recurrent_2"
,
//
"INPUT/fixed_channel_0_index_0_ids"
,
//
"INPUT/linked_channel_0_activations"
};
flow
.
AddVariable
(
"z"
,
sling
::
myelin
::
DT_INT32
,
{})
->
aliases
=
{
"OUTPUT/output_1"
,
//
"OUTPUT/output_2"
};
const
std
::
set
<
string
>
expected_names
=
{
"output_1"
,
"output_2"
};
EXPECT_EQ
(
GetOutputLayerNames
(
flow
),
expected_names
);
}
TEST
(
MyelinSpecUtilsTest
,
MakeMyelinInputFixedFeatureIdName
)
{
EXPECT_EQ
(
MakeMyelinInputFixedFeatureIdName
(
0
,
1
),
"INPUT/fixed_channel_0_index_1_ids"
);
EXPECT_EQ
(
MakeMyelinInputFixedFeatureIdName
(
1
,
0
),
"INPUT/fixed_channel_1_index_0_ids"
);
}
TEST
(
MyelinSpecUtilsTest
,
MakeMyelinInputLinkedActivationVectorName
)
{
EXPECT_EQ
(
MakeMyelinInputLinkedActivationVectorName
(
0
),
"INPUT/linked_channel_0_activations"
);
EXPECT_EQ
(
MakeMyelinInputLinkedActivationVectorName
(
1
),
"INPUT/linked_channel_1_activations"
);
}
TEST
(
MyelinSpecUtilsTest
,
MakeMyelinInputLinkedOutOfBoundsIndicatorName
)
{
EXPECT_EQ
(
MakeMyelinInputLinkedOutOfBoundsIndicatorName
(
0
),
"INPUT/linked_channel_0_out_of_bounds"
);
EXPECT_EQ
(
MakeMyelinInputLinkedOutOfBoundsIndicatorName
(
1
),
"INPUT/linked_channel_1_out_of_bounds"
);
}
TEST
(
MyelinSpecUtilsTest
,
MakeMyelinInputRecurrentLayerName
)
{
EXPECT_EQ
(
MakeMyelinInputRecurrentLayerName
(
"foo"
),
"INPUT/foo"
);
EXPECT_EQ
(
MakeMyelinInputRecurrentLayerName
(
"bar_baz"
),
"INPUT/bar_baz"
);
}
TEST
(
MyelinSpecUtilsTest
,
MakeMyelinOutputLayerName
)
{
EXPECT_EQ
(
MakeMyelinOutputLayerName
(
"foo"
),
"OUTPUT/foo"
);
EXPECT_EQ
(
MakeMyelinOutputLayerName
(
"bar_baz"
),
"OUTPUT/bar_baz"
);
}
}
// namespace
}
// namespace runtime
}
// namespace dragnn
}
// namespace syntaxnet
research/syntaxnet/dragnn/runtime/myelin/myelin_tracing.cc
0 → 100644
View file @
a4bb31d0
// Copyright 2017 Google Inc. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
#include "dragnn/runtime/myelin/myelin_tracing.h"
#include <map>
#include <string>
#include "syntaxnet/base.h"
namespace
syntaxnet
{
namespace
dragnn
{
namespace
runtime
{
namespace
{
// Copies |num_values| |T|s from |data| into the |tensor_trace|. If |T| does
// not match the |type|, returns false and modifies nothing. The bool return
// allows this function to be chained until a matching type is found.
template
<
class
T
>
bool
TryCopyValues
(
sling
::
myelin
::
Type
type
,
const
char
*
data
,
int
num_values
,
CellTensorTrace
*
tensor_trace
)
{
if
(
sling
::
myelin
::
Traits
<
T
>
().
type
()
!=
type
)
return
false
;
const
T
*
begin
=
reinterpret_cast
<
const
T
*>
(
data
);
const
T
*
end
=
begin
+
num_values
;
tensor_trace
->
clear_value
();
for
(;
begin
!=
end
;
++
begin
)
tensor_trace
->
add_value
(
*
begin
);
return
true
;
}
}
// namespace
void
TraceMyelinInstance
(
sling
::
myelin
::
Instance
*
instance
,
CellTrace
*
cell_trace
)
{
const
sling
::
myelin
::
Cell
&
cell
=
*
instance
->
cell
();
cell_trace
->
Clear
();
cell_trace
->
set_name
(
cell
.
name
());
// Collect steps and tensors in sorted maps for deterministic ordering.
std
::
map
<
string
,
const
sling
::
myelin
::
Step
*>
steps
;
std
::
map
<
string
,
sling
::
myelin
::
Tensor
*>
tensors
;
for
(
const
sling
::
myelin
::
Step
*
step
:
cell
.
steps
())
{
steps
[
step
->
name
()]
=
step
;
for
(
sling
::
myelin
::
Tensor
*
tensor
:
step
->
inputs
())
{
tensors
[
tensor
->
name
()]
=
tensor
;
}
for
(
sling
::
myelin
::
Tensor
*
tensor
:
step
->
outputs
())
{
tensors
[
tensor
->
name
()]
=
tensor
;
}
}
// Trace each step as an operation.
for
(
const
auto
&
it
:
steps
)
{
const
sling
::
myelin
::
Step
*
step
=
it
.
second
;
CellOperationTrace
*
operation_trace
=
cell_trace
->
add_operation
();
operation_trace
->
set_name
(
step
->
name
());
operation_trace
->
set_type
(
step
->
type
());
operation_trace
->
set_kernel
(
step
->
kernel
()
->
Name
());
for
(
sling
::
myelin
::
Tensor
*
tensor
:
step
->
inputs
())
{
operation_trace
->
add_input
(
tensor
->
name
());
}
for
(
sling
::
myelin
::
Tensor
*
tensor
:
step
->
outputs
())
{
operation_trace
->
add_output
(
tensor
->
name
());
}
}
// Trace each tensor and its value.
for
(
const
auto
&
it
:
tensors
)
{
sling
::
myelin
::
Tensor
*
tensor
=
it
.
second
;
if
(
!
tensor
->
IsLocal
())
continue
;
// ignore globals; e.g., weight matrices
const
string
&
name
=
tensor
->
name
();
const
sling
::
myelin
::
Type
type
=
tensor
->
type
();
// Find the variable data for the |tensor|. Note that ref tensors need to
// be dereferenced.
const
char
*
data
=
instance
->
GetAddress
(
tensor
);
if
(
tensor
->
ref
())
data
=
*
reinterpret_cast
<
const
char
*
const
*>
(
data
);
const
int
size
=
tensor
->
aligned
().
elements
();
CellTensorTrace
*
tensor_trace
=
cell_trace
->
add_tensor
();
tensor_trace
->
set_name
(
name
);
tensor_trace
->
set_type
(
sling
::
myelin
::
TypeTraits
::
of
(
type
).
name
());
for
(
int
i
=
0
;
i
<
tensor
->
rank
();
++
i
)
{
tensor_trace
->
add_dimension
(
tensor
->
dim
(
i
));
tensor_trace
->
add_aligned_dimension
(
tensor
->
aligned
(
i
));
}
switch
(
tensor
->
order
())
{
case
sling
::
myelin
::
ROW_MAJOR
:
tensor_trace
->
set_order
(
CellTensorTrace
::
ORDER_ROW_MAJOR
);
break
;
case
sling
::
myelin
::
COLUMN_MAJOR
:
tensor_trace
->
set_order
(
CellTensorTrace
::
ORDER_COLUMN_MAJOR
);
break
;
default:
break
;
}
// Try copying tensor data using all relevant types. At most one attempt
// will succeed and modify the |tensor_trace|.
if
(
!
TryCopyValues
<
float
>
(
type
,
data
,
size
,
tensor_trace
)
&&
!
TryCopyValues
<
double
>
(
type
,
data
,
size
,
tensor_trace
)
&&
!
TryCopyValues
<
bool
>
(
type
,
data
,
size
,
tensor_trace
)
&&
!
TryCopyValues
<
int8
>
(
type
,
data
,
size
,
tensor_trace
)
&&
!
TryCopyValues
<
int16
>
(
type
,
data
,
size
,
tensor_trace
)
&&
!
TryCopyValues
<
int32
>
(
type
,
data
,
size
,
tensor_trace
)
&&
!
TryCopyValues
<
int64
>
(
type
,
data
,
size
,
tensor_trace
)
&&
!
TryCopyValues
<
uint8
>
(
type
,
data
,
size
,
tensor_trace
)
&&
!
TryCopyValues
<
uint16
>
(
type
,
data
,
size
,
tensor_trace
))
{
LOG
(
WARNING
)
<<
"Can't convert data for tensor "
<<
name
<<
" with type "
<<
tensor_trace
->
type
();
}
}
}
}
// namespace runtime
}
// namespace dragnn
}
// namespace syntaxnet
research/syntaxnet/dragnn/runtime/myelin/myelin_tracing.h
0 → 100644
View file @
a4bb31d0
// Copyright 2017 Google Inc. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
#ifndef DRAGNN_RUNTIME_MYELIN_MYELIN_TRACING_H_
#define DRAGNN_RUNTIME_MYELIN_MYELIN_TRACING_H_
#include "dragnn/protos/cell_trace.pb.h"
#include "sling/myelin/compute.h"
namespace
syntaxnet
{
namespace
dragnn
{
namespace
runtime
{
// Overwrites the |cell_trace| with traces extracted from the |instance|. Does
// not modify the |instance|; it is non-const because the relevant accessors are
// declared non-const.
void
TraceMyelinInstance
(
sling
::
myelin
::
Instance
*
instance
,
CellTrace
*
cell_trace
);
}
// namespace runtime
}
// namespace dragnn
}
// namespace syntaxnet
#endif // DRAGNN_RUNTIME_MYELIN_MYELIN_TRACING_H_
research/syntaxnet/dragnn/runtime/myelin/myelin_tracing_test.cc
0 → 100644
View file @
a4bb31d0
// Copyright 2017 Google Inc. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
#include "dragnn/runtime/myelin/myelin_tracing.h"
#include <string.h>
#include <string>
#include "dragnn/core/test/generic.h"
#include "dragnn/protos/cell_trace.pb.h"
#include "dragnn/runtime/myelin/myelin_spec_utils.h"
#include "dragnn/runtime/test/helpers.h"
#include "syntaxnet/base.h"
#include "sling/myelin/compute.h"
#include "sling/myelin/flow.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/dynamic_annotations.h"
#include "tensorflow/core/platform/test.h"
namespace
syntaxnet
{
namespace
dragnn
{
namespace
runtime
{
namespace
{
// Name of the dummy cell for tests.
constexpr
char
kCellName
[]
=
"test_cell"
;
// Returns a CellTrace parsed from the concatenation of the |args|.
template
<
class
...
Args
>
CellTrace
ParseCellTrace
(
const
Args
&
...
args
)
{
const
string
text_proto
=
tensorflow
::
strings
::
StrCat
(
args
...);
CellTrace
cell_trace
;
CHECK
(
TextFormat
::
ParseFromString
(
text_proto
,
&
cell_trace
));
return
cell_trace
;
}
// Testing rig.
class
TraceMyelinInstanceTest
:
public
::
testing
::
Test
{
protected:
// Compiles the |flow_|, binds the name=>data |feeds|, evaluates the cell, and
// returns an extracted trace.
CellTrace
GetTrace
(
const
std
::
map
<
string
,
MutableAlignedView
>
&
feeds
)
{
sling
::
myelin
::
Library
library
;
RegisterMyelinLibraries
(
&
library
);
LOG
(
INFO
)
<<
"Original flow:
\n
"
<<
flow_
.
ToString
();
flow_
.
Analyze
(
library
);
LOG
(
INFO
)
<<
"Analyzed flow:
\n
"
<<
flow_
.
ToString
();
sling
::
myelin
::
Network
network
;
CHECK
(
network
.
Compile
(
flow_
,
library
));
const
sling
::
myelin
::
Cell
*
cell
=
network
.
GetCell
(
kCellName
);
CHECK
(
cell
!=
nullptr
)
<<
"Unknown cell: "
<<
kCellName
;
sling
::
myelin
::
Instance
instance
(
cell
);
for
(
const
auto
&
it
:
feeds
)
{
const
string
&
name
=
it
.
first
;
char
*
data
=
it
.
second
.
data
();
sling
::
myelin
::
Tensor
*
tensor
=
network
.
GetParameter
(
name
);
CHECK
(
tensor
!=
nullptr
)
<<
"Unknown tensor: "
<<
name
;
instance
.
SetReference
(
tensor
,
data
);
}
instance
.
Compute
();
CellTrace
cell_trace
;
TraceMyelinInstance
(
&
instance
,
&
cell_trace
);
return
cell_trace
;
}
// Flow, to be modified in each test.
sling
::
myelin
::
Flow
flow_
;
// The function to trace. Each test should add operations to this.
sling
::
myelin
::
Flow
::
Function
*
function_
=
flow_
.
AddFunction
(
kCellName
);
};
// Tests tracing on a simple cell with one operation. In this cell, both the
// input and output are Tensor refs and need to be fed.
TEST_F
(
TraceMyelinInstanceTest
,
SingleOperation
)
{
sling
::
myelin
::
Flow
::
Variable
*
input
=
flow_
.
AddVariable
(
"input"
,
sling
::
myelin
::
DT_FLOAT
,
{
1
});
input
->
in
=
true
;
input
->
ref
=
true
;
sling
::
myelin
::
Flow
::
Variable
*
one
=
flow_
.
AddVariable
(
"one"
,
sling
::
myelin
::
DT_FLOAT
,
{
1
});
constexpr
float
kOne
=
1.0
;
one
->
SetData
(
&
kOne
,
sizeof
(
float
));
sling
::
myelin
::
Flow
::
Variable
*
axis
=
flow_
.
AddVariable
(
"axis"
,
sling
::
myelin
::
DT_INT32
,
{
1
});
constexpr
int32
kAxis
=
0
;
axis
->
SetData
(
&
kAxis
,
sizeof
(
int32
));
sling
::
myelin
::
Flow
::
Variable
*
output
=
flow_
.
AddVariable
(
"output"
,
sling
::
myelin
::
DT_FLOAT
,
{
2
});
output
->
out
=
true
;
output
->
ref
=
true
;
sling
::
myelin
::
Flow
::
Operation
*
concat
=
flow_
.
AddOperation
(
function_
,
"concat"
,
"ConcatV2"
,
{
input
,
one
,
axis
},
{
output
});
concat
->
SetAttr
(
"N"
,
2
);
UniqueVector
<
float
>
input_feed
(
1
);
UniqueVector
<
float
>
output_feed
(
2
);
(
*
input_feed
)[
0
]
=
-
1.5
;
TF_ANNOTATE_MEMORY_IS_INITIALIZED
(
output_feed
->
data
(),
output_feed
->
size
()
*
sizeof
(
float
));
const
std
::
map
<
string
,
MutableAlignedView
>
feeds
=
{
{
"input"
,
input_feed
.
view
()},
//
{
"output"
,
output_feed
.
view
()}};
const
CellTrace
expected_trace
=
ParseCellTrace
(
R"(
name: ')"
,
kCellName
,
R"('
tensor {
name: 'input'
type: 'float32'
dimension: [1]
aligned_dimension: [1]
order: ORDER_ROW_MAJOR
value: [-1.5]
}
tensor {
name: 'output'
type: 'float32'
dimension: [2]
aligned_dimension: [2]
order: ORDER_ROW_MAJOR
value: [-1.5, 1.0]
}
operation {
name: 'concat'
type: 'ConcatV2'
kernel: 'BasicConcat'
input: ['input', 'one', 'axis']
output: ['output']
}
)"
);
EXPECT_THAT
(
GetTrace
(
feeds
),
test
::
EqualsProto
(
expected_trace
));
EXPECT_EQ
((
*
output_feed
)[
0
],
-
1.5
);
EXPECT_EQ
((
*
output_feed
)[
1
],
1.0
);
}
// Tests tracing on a slightly more complex cell with a few operations. In this
// case, only the input is a Tensor ref and needs to be fed.
TEST_F
(
TraceMyelinInstanceTest
,
MultiOperation
)
{
sling
::
myelin
::
Flow
::
Variable
*
input
=
flow_
.
AddVariable
(
"input"
,
sling
::
myelin
::
DT_FLOAT
,
{
1
});
input
->
in
=
true
;
input
->
ref
=
true
;
sling
::
myelin
::
Flow
::
Variable
*
one
=
flow_
.
AddVariable
(
"one"
,
sling
::
myelin
::
DT_FLOAT
,
{
1
});
constexpr
float
kOne
=
1.0
;
one
->
SetData
(
&
kOne
,
sizeof
(
float
));
sling
::
myelin
::
Flow
::
Variable
*
two
=
flow_
.
AddVariable
(
"two"
,
sling
::
myelin
::
DT_FLOAT
,
{
1
});
constexpr
float
kTwo
=
2.0
;
two
->
SetData
(
&
kTwo
,
sizeof
(
float
));
sling
::
myelin
::
Flow
::
Variable
*
three
=
flow_
.
AddVariable
(
"three"
,
sling
::
myelin
::
DT_FLOAT
,
{
1
});
constexpr
float
kThree
=
3.0
;
three
->
SetData
(
&
kThree
,
sizeof
(
float
));
sling
::
myelin
::
Flow
::
Variable
*
four
=
flow_
.
AddVariable
(
"four"
,
sling
::
myelin
::
DT_FLOAT
,
{
1
});
constexpr
float
kFour
=
4.0
;
four
->
SetData
(
&
kFour
,
sizeof
(
float
));
sling
::
myelin
::
Flow
::
Variable
*
axis
=
flow_
.
AddVariable
(
"axis"
,
sling
::
myelin
::
DT_INT32
,
{
1
});
constexpr
int32
kAxis
=
0
;
axis
->
SetData
(
&
kAxis
,
sizeof
(
int32
));
sling
::
myelin
::
Flow
::
Variable
*
local_1
=
flow_
.
AddVariable
(
"local_1"
,
sling
::
myelin
::
DT_FLOAT
,
{
3
});
sling
::
myelin
::
Flow
::
Variable
*
local_2
=
flow_
.
AddVariable
(
"local_2"
,
sling
::
myelin
::
DT_FLOAT
,
{
3
});
sling
::
myelin
::
Flow
::
Variable
*
output
=
flow_
.
AddVariable
(
"output"
,
sling
::
myelin
::
DT_FLOAT
,
{
6
});
output
->
out
=
true
;
sling
::
myelin
::
Flow
::
Operation
*
concat_1
=
flow_
.
AddOperation
(
function_
,
"concat_1"
,
"ConcatV2"
,
{
one
,
input
,
two
,
axis
},
{
local_1
});
concat_1
->
SetAttr
(
"N"
,
3
);
sling
::
myelin
::
Flow
::
Operation
*
concat_2
=
flow_
.
AddOperation
(
function_
,
"concat_2"
,
"ConcatV2"
,
{
three
,
four
,
input
,
axis
},
{
local_2
});
concat_2
->
SetAttr
(
"N"
,
3
);
sling
::
myelin
::
Flow
::
Operation
*
concat_3
=
flow_
.
AddOperation
(
function_
,
"concat_3"
,
"ConcatV2"
,
{
local_1
,
local_2
,
axis
},
{
output
});
concat_3
->
SetAttr
(
"N"
,
2
);
UniqueVector
<
float
>
input_feed
(
1
);
(
*
input_feed
)[
0
]
=
0.75
;
const
std
::
map
<
string
,
MutableAlignedView
>
feeds
=
{
{
"input"
,
input_feed
.
view
()}};
const
CellTrace
expected_trace
=
ParseCellTrace
(
R"(
name: ')"
,
kCellName
,
R"('
tensor {
name: 'input'
type: 'float32'
dimension: [1]
aligned_dimension: [1]
order: ORDER_ROW_MAJOR
value: [0.75]
}
tensor {
name: 'local_1'
type: 'float32'
dimension: [3]
aligned_dimension: [3]
order: ORDER_ROW_MAJOR
value: [1.0, 0.75, 2.0]
}
tensor {
name: 'local_2'
type: 'float32'
dimension: [3]
aligned_dimension: [3]
order: ORDER_ROW_MAJOR
value: [3.0, 4.0, 0.75]
}
tensor {
name: 'output'
type: 'float32'
dimension: [6]
aligned_dimension: [6]
order: ORDER_ROW_MAJOR
value: [1.0, 0.75, 2.0, 3.0, 4.0, 0.75]
}
operation {
name: 'concat_1'
type: 'ConcatV2'
kernel: 'BasicConcat'
input: ['one', 'input', 'two', 'axis']
output: ['local_1']
}
operation {
name: 'concat_2'
type: 'ConcatV2'
kernel: 'BasicConcat'
input: ['three', 'four', 'input', 'axis']
output: ['local_2']
}
operation {
name: 'concat_3'
type: 'ConcatV2'
kernel: 'BasicConcat'
input: ['local_1', 'local_2', 'axis']
output: ['output']
}
)"
);
EXPECT_THAT
(
GetTrace
(
feeds
),
test
::
EqualsProto
(
expected_trace
));
}
// Tests tracing on a flow that contains an unsupported type: complex128. In
// this case, the tensor values will be missing, but the rest of the trace is
// still extracted.
TEST_F
(
TraceMyelinInstanceTest
,
UnsupportedType
)
{
sling
::
myelin
::
Flow
::
Variable
*
input
=
flow_
.
AddVariable
(
"input"
,
sling
::
myelin
::
DT_COMPLEX128
,
{
1
});
input
->
in
=
true
;
input
->
ref
=
true
;
sling
::
myelin
::
Flow
::
Variable
*
zero
=
flow_
.
AddVariable
(
"zero"
,
sling
::
myelin
::
DT_COMPLEX128
,
{
1
});
const
std
::
vector
<
char
>
bytes
(
2
*
sizeof
(
uint64
));
zero
->
SetData
(
bytes
.
data
(),
bytes
.
size
());
sling
::
myelin
::
Flow
::
Variable
*
axis
=
flow_
.
AddVariable
(
"axis"
,
sling
::
myelin
::
DT_INT32
,
{
1
});
constexpr
int32
kAxis
=
0
;
axis
->
SetData
(
&
kAxis
,
sizeof
(
int32
));
sling
::
myelin
::
Flow
::
Variable
*
output
=
flow_
.
AddVariable
(
"output"
,
sling
::
myelin
::
DT_COMPLEX128
,
{
2
});
output
->
out
=
true
;
output
->
ref
=
true
;
sling
::
myelin
::
Flow
::
Operation
*
concat
=
flow_
.
AddOperation
(
function_
,
"concat"
,
"ConcatV2"
,
{
input
,
zero
,
axis
},
{
output
});
concat
->
SetAttr
(
"N"
,
2
);
// Both the input and output are refs and need to be fed.
UniqueVector
<
char
>
input_feed
(
2
*
sizeof
(
uint64
));
UniqueVector
<
char
>
output_feed
(
4
*
sizeof
(
uint64
));
const
std
::
map
<
string
,
MutableAlignedView
>
feeds
=
{
{
"input"
,
input_feed
.
view
()},
//
{
"output"
,
output_feed
.
view
()}};
memset
(
input_feed
->
data
(),
0
,
input_feed
->
size
());
const
CellTrace
expected_trace
=
ParseCellTrace
(
R"(
name: ')"
,
kCellName
,
R"('
tensor {
name: 'input'
type: 'complex128'
dimension: [1]
aligned_dimension: [1]
order: ORDER_ROW_MAJOR
}
tensor {
name: 'output'
type: 'complex128'
dimension: [2]
aligned_dimension: [2]
order: ORDER_ROW_MAJOR
}
operation {
name: 'concat'
type: 'ConcatV2'
kernel: 'BasicConcat'
input: ['input', 'zero', 'axis']
output: ['output']
}
)"
);
EXPECT_THAT
(
GetTrace
(
feeds
),
test
::
EqualsProto
(
expected_trace
));
}
}
// namespace
}
// namespace runtime
}
// namespace dragnn
}
// namespace syntaxnet
research/syntaxnet/dragnn/runtime/myelin/myelination.cc
0 → 100644
View file @
a4bb31d0
// Copyright 2017 Google Inc. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
#include "dragnn/runtime/myelin/myelination.h"
#include <algorithm>
#include <memory>
#include <utility>
#include <vector>
#include "dragnn/protos/spec.pb.h"
#include "dragnn/runtime/component.h"
#include "dragnn/runtime/myelin/myelin_cell_converter.h"
#include "dragnn/runtime/myelin/myelin_spec_utils.h"
#include "dragnn/runtime/trained_model.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/io/path.h"
#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/logging.h"
namespace
syntaxnet
{
namespace
dragnn
{
namespace
runtime
{
namespace
{
// Updates the Component subclass in the |component_spec| to a Myelin-based
// version. On error, returns non-OK and modifies nothing.
tensorflow
::
Status
MyelinateComponentSubclass
(
ComponentSpec
*
component_spec
)
{
const
string
subclass
=
GetNormalizedComponentBuilderName
(
*
component_spec
);
if
(
subclass
!=
"DynamicComponent"
)
{
return
tensorflow
::
errors
::
Unimplemented
(
"No Myelin-based version of Component subclass '"
,
subclass
,
"'"
);
}
// By convention, the Myelin-based version of "FooComponent" should be named
// "MyelinFooComponent".
component_spec
->
mutable_component_builder
()
->
set_registered_name
(
tensorflow
::
strings
::
StrCat
(
"Myelin"
,
subclass
));
return
tensorflow
::
Status
::
OK
();
}
// Appends the list of component specs in the |master_spec| whose names match
// |component_names| to |matching_components|. On error, returns non-OK.
tensorflow
::
Status
GetMatchingComponentSpecs
(
const
std
::
set
<
string
>
&
component_names
,
MasterSpec
*
master_spec
,
std
::
vector
<
ComponentSpec
*>
*
matching_components
)
{
// Index the components in the |master_spec| by name.
std
::
map
<
string
,
ComponentSpec
*>
components
;
for
(
ComponentSpec
&
component_spec
:
*
master_spec
->
mutable_component
())
{
if
(
!
components
.
emplace
(
component_spec
.
name
(),
&
component_spec
).
second
)
{
return
tensorflow
::
errors
::
InvalidArgument
(
"Duplicate component name: "
,
component_spec
.
name
());
}
}
// Append the components named in the |component_names|.
for
(
const
string
&
component_name
:
component_names
)
{
if
(
components
.
find
(
component_name
)
==
components
.
end
())
{
return
tensorflow
::
errors
::
InvalidArgument
(
"Unknown component name: "
,
component_name
);
}
matching_components
->
push_back
(
components
[
component_name
]);
}
return
tensorflow
::
Status
::
OK
();
}
}
// namespace
tensorflow
::
Status
MyelinateCells
(
const
string
&
saved_model_dir
,
const
string
&
master_spec_path
,
const
std
::
set
<
string
>
&
component_names
,
const
string
&
output_dir
)
{
MasterSpec
master_spec
;
TF_RETURN_IF_ERROR
(
tensorflow
::
ReadTextProto
(
tensorflow
::
Env
::
Default
(),
master_spec_path
,
&
master_spec
));
std
::
vector
<
ComponentSpec
*>
components
;
TF_RETURN_IF_ERROR
(
GetMatchingComponentSpecs
(
component_names
,
&
master_spec
,
&
components
));
// Returns the path to the output Flow file for the |component_spec|.
const
auto
get_flow_path
=
[
&
](
const
ComponentSpec
&
component_spec
)
{
return
tensorflow
::
io
::
JoinPath
(
output_dir
,
tensorflow
::
strings
::
StrCat
(
component_spec
.
name
(),
".flow"
));
};
// Modify the MasterSpec first, to catch issues before loading the trained
// model, which is slow.
for
(
ComponentSpec
*
component_spec
:
components
)
{
// Add a resource for the Flow file to each component. The file will be
// created in a second pass, after loading the trained model.
TF_RETURN_IF_ERROR
(
AddMyelinFlowResource
(
get_flow_path
(
*
component_spec
),
component_spec
));
// Replace the Component subclass with a Myelin-based version.
TF_RETURN_IF_ERROR
(
MyelinateComponentSubclass
(
component_spec
));
// Set embedding_dim=-1 for all channels.
for
(
auto
&
fixed_channel
:
*
component_spec
->
mutable_fixed_feature
())
{
fixed_channel
.
set_embedding_dim
(
-
1
);
}
for
(
auto
&
linked_channel
:
*
component_spec
->
mutable_linked_feature
())
{
linked_channel
.
set_embedding_dim
(
-
1
);
}
}
// Write the updated MasterSpec.
TF_RETURN_IF_ERROR
(
tensorflow
::
Env
::
Default
()
->
RecursivelyCreateDir
(
output_dir
));
TF_RETURN_IF_ERROR
(
tensorflow
::
WriteTextProto
(
tensorflow
::
Env
::
Default
(),
tensorflow
::
io
::
JoinPath
(
output_dir
,
"master-spec"
),
master_spec
));
// Convert each component into a Flow and write it.
TrainedModel
trained_model
;
TF_RETURN_IF_ERROR
(
trained_model
.
Reset
(
saved_model_dir
));
for
(
const
ComponentSpec
*
component_spec
:
components
)
{
string
flow_data
;
TF_RETURN_IF_ERROR
(
MyelinCellConverter
::
Convert
(
component_spec
->
name
(),
trained_model
,
&
flow_data
));
TF_RETURN_IF_ERROR
(
tensorflow
::
WriteStringToFile
(
tensorflow
::
Env
::
Default
(),
get_flow_path
(
*
component_spec
),
flow_data
));
}
return
tensorflow
::
Status
::
OK
();
}
}
// namespace runtime
}
// namespace dragnn
}
// namespace syntaxnet
research/syntaxnet/dragnn/runtime/myelin/myelination.h
0 → 100644
View file @
a4bb31d0
// Copyright 2017 Google Inc. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
// Utils for modifying pre-trained models to use Myelin.
#ifndef DRAGNN_RUNTIME_MYELIN_MYELINATION_H_
#define DRAGNN_RUNTIME_MYELIN_MYELINATION_H_
#include <set>
#include <string>
#include "syntaxnet/base.h"
#include "tensorflow/core/lib/core/status.h"
namespace
syntaxnet
{
namespace
dragnn
{
namespace
runtime
{
// Modifies a DRAGNN model to use Myelin.
//
// Loads a TF SavedModel from the |saved_model_dir| and a text-format MasterSpec
// from the |master_spec_path|. Converts each component in |component_names|
// into a Myelin Flow (see myelin_cell_converter.h) and writes the results to
// the |output_dir| as files "<output_dir>/<component_name>.flow". Modifies the
// relevant ComponentSpecs in the MasterSpec to use Myelin as described below,
// and writes it to "<output_dir>/master-spec".
//
// MasterSpec modifications:
// * Adds a resource to each ComponentSpec that points at the relevant Flow file
// in the |output_dir|.
// * Replaces the Component subclass specified in each ComponentSpec with the
// Myelin-based equivalent, which should be named "Myelin<subclass_name>";
// e.g., MyelinDynamicComponent.
// * Sets FixedFeatureChannel.embedding_dim to -1 in all channels, because
// Myelin takes feature IDs as input instead of fixed embedding sums.
// * Sets LinkedFeatureChannel.embedding_dim to -1 in all channels, because
// Myelin handles the linked embedding matrix multiplication (if any) and
// always takes the original activation vector as input.
//
// On error, returns non-OK. Possible errors include:
// * Any file I/O or proto parsing error.
// * The MasterSpec has a duplicate component name.
// * One of the |component_names| does not match anything in the MasterSpec.
// * The MasterSpec already has Myelin Flow resources.
// * One of the components is not supported by Myelin.
// * Error raised by MyelinCellConverter during conversion.
//
// Side note: This function has a file-path-based API so it can be easily
// wrapped in a stand-alone binary.
tensorflow
::
Status
MyelinateCells
(
const
string
&
saved_model_dir
,
const
string
&
master_spec_path
,
const
std
::
set
<
string
>
&
component_names
,
const
string
&
output_dir
);
}
// namespace runtime
}
// namespace dragnn
}
// namespace syntaxnet
#endif // DRAGNN_RUNTIME_MYELIN_MYELINATION_H_
research/syntaxnet/dragnn/runtime/myelin/myelination_test.cc
0 → 100644
View file @
a4bb31d0
// Copyright 2017 Google Inc. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
#include "dragnn/runtime/myelin/myelination.h"
#include <memory>
#include <string>
#include <utility>
#include "dragnn/core/test/generic.h"
#include "dragnn/protos/spec.pb.h"
#include "dragnn/runtime/myelin/myelin_spec_utils.h"
#include "syntaxnet/base.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/lib/io/path.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/test.h"
namespace
syntaxnet
{
namespace
dragnn
{
namespace
runtime
{
namespace
{
// Arbitrary bogus path.
constexpr
char
kInvalidPath
[]
=
"path/to/some/invalid/file"
;
// Relative path to a MasterSpec.
constexpr
char
kMasterSpecPath
[]
=
"dragnn/runtime/testdata/rnn_tagger/assets.extra/master_spec"
;
// Relative path to a saved model.
constexpr
char
kSavedModelDir
[]
=
"dragnn/runtime/testdata/rnn_tagger"
;
// Relative path to a directory containing expected output.
constexpr
char
kExpectedOutputDir
[]
=
"dragnn/runtime/myelin/testdata/myelination_output"
;
// Local relative path to the expected output directory.
constexpr
char
kLocalOutputDir
[]
=
"dragnn/runtime/myelin/testdata/myelination_output"
;
// Returns the set of components in the MasterSpec at |kMasterSpecPath|.
std
::
set
<
string
>
GetComponentNames
()
{
return
{
"rnn"
,
"tagger"
};
}
// Returns the path to a test input denoted by the |relative_path|.
string
GetInput
(
const
string
&
relative_path
)
{
return
tensorflow
::
io
::
JoinPath
(
test
::
GetTestDataPrefix
(),
relative_path
);
}
// Returns a unique output directory for tests.
string
GetUniqueOutputDir
()
{
static
int
counter
=
0
;
return
tensorflow
::
io
::
JoinPath
(
tensorflow
::
testing
::
TmpDir
(),
tensorflow
::
strings
::
StrCat
(
"output_"
,
counter
++
));
}
// Compares the content of the file named |basename| in the |actual_output_dir|
// with the file with the same |basename| in |kExpectedOutputDir|. Can also be
// modified to write the actual file content to |kLocalOutputDir|, for updating
// test expectations.
void
CompareOrRewriteTestData
(
const
string
&
actual_output_dir
,
const
string
&
basename
)
{
string
actual_data
;
TF_ASSERT_OK
(
tensorflow
::
ReadFileToString
(
tensorflow
::
Env
::
Default
(),
tensorflow
::
io
::
JoinPath
(
actual_output_dir
,
basename
),
&
actual_data
));
if
(
false
)
{
TF_ASSERT_OK
(
tensorflow
::
WriteStringToFile
(
tensorflow
::
Env
::
Default
(),
tensorflow
::
io
::
JoinPath
(
kLocalOutputDir
,
basename
),
actual_data
));
}
else
{
string
expected_data
;
TF_ASSERT_OK
(
tensorflow
::
ReadFileToString
(
tensorflow
::
Env
::
Default
(),
GetInput
(
tensorflow
::
io
::
JoinPath
(
kExpectedOutputDir
,
basename
)),
&
expected_data
));
// Avoid EXPECT_EQ(), which produces a text diff on error. The diff is not
// interpretable because Flow files are binary, and the test can OOM when it
// tries to diff two large binary files.
EXPECT_TRUE
(
actual_data
==
expected_data
);
}
}
// Reads a text-format MasterSpec from the |master_spec_path|, clears resource
// file patterns, and writes it back to the |master_spec_path|. The resource
// file patterns would otherwise cause spurious mismatches.
void
ClearResourceFilePatterns
(
const
string
&
master_spec_path
)
{
MasterSpec
master_spec
;
TF_ASSERT_OK
(
tensorflow
::
ReadTextProto
(
tensorflow
::
Env
::
Default
(),
master_spec_path
,
&
master_spec
));
for
(
ComponentSpec
&
component_spec
:
*
master_spec
.
mutable_component
())
{
for
(
Resource
&
resource
:
*
component_spec
.
mutable_resource
())
{
for
(
Part
&
part
:
*
resource
.
mutable_part
())
{
part
.
clear_file_pattern
();
}
}
}
TF_ASSERT_OK
(
tensorflow
::
WriteTextProto
(
tensorflow
::
Env
::
Default
(),
master_spec_path
,
master_spec
));
}
// Tests that MyelinateCells() fails if the saved model is invalid.
TEST
(
MyelinateCellsTest
,
InvalidSavedModel
)
{
EXPECT_FALSE
(
MyelinateCells
(
kInvalidPath
,
GetInput
(
kMasterSpecPath
),
{},
GetUniqueOutputDir
())
.
ok
());
}
// Tests that MyelinateCells() fails if the master spec is invalid.
TEST
(
MyelinateCellsTest
,
InvalidMasterSpec
)
{
EXPECT_FALSE
(
MyelinateCells
(
GetInput
(
kSavedModelDir
),
kInvalidPath
,
{},
GetUniqueOutputDir
())
.
ok
());
}
// Tests that MyelinateCells() fails if the MasterSpec contains a duplicate
// component.
TEST
(
MyelinateCellsTest
,
DuplicateComponent
)
{
const
string
kSpec
=
"component { name:'foo' } component { name:'foo' }"
;
const
string
master_spec_path
=
tensorflow
::
io
::
JoinPath
(
tensorflow
::
testing
::
TmpDir
(),
"master-spec-with-duplicate"
);
TF_ASSERT_OK
(
tensorflow
::
WriteStringToFile
(
tensorflow
::
Env
::
Default
(),
master_spec_path
,
kSpec
));
EXPECT_THAT
(
MyelinateCells
(
GetInput
(
kSavedModelDir
),
master_spec_path
,
{},
GetUniqueOutputDir
()),
test
::
IsErrorWithSubstr
(
"Duplicate component name: foo"
));
}
// Tests that MyelinateCells() fails if one of the requested components does not
// appear in the MasterSpec.
TEST
(
MyelinateCellsTest
,
FilterWithUnknownComponent
)
{
const
string
kSpec
=
"component { name:'foo' } component { name:'bar' }"
;
const
string
master_spec_path
=
tensorflow
::
io
::
JoinPath
(
tensorflow
::
testing
::
TmpDir
(),
"master-spec-foo-bar"
);
TF_ASSERT_OK
(
tensorflow
::
WriteStringToFile
(
tensorflow
::
Env
::
Default
(),
master_spec_path
,
kSpec
));
EXPECT_THAT
(
MyelinateCells
(
GetInput
(
kSavedModelDir
),
master_spec_path
,
{
"missing"
},
GetUniqueOutputDir
()),
test
::
IsErrorWithSubstr
(
"Unknown component name: missing"
));
}
// Tests that MyelinateCells() fails if a component already has a Myelin Flow.
TEST
(
MyelinateCellsTest
,
AlreadyHasFlow
)
{
const
string
kSpec
=
tensorflow
::
strings
::
StrCat
(
"component { name: 'foo' resource { name: '"
,
kMyelinFlowResourceName
,
"' } }"
);
const
string
master_spec_path
=
tensorflow
::
io
::
JoinPath
(
tensorflow
::
testing
::
TmpDir
(),
"master-spec-with-flows"
);
TF_ASSERT_OK
(
tensorflow
::
WriteStringToFile
(
tensorflow
::
Env
::
Default
(),
master_spec_path
,
kSpec
));
EXPECT_THAT
(
MyelinateCells
(
GetInput
(
kSavedModelDir
),
master_spec_path
,
{
"foo"
},
GetUniqueOutputDir
()),
test
::
IsErrorWithSubstr
(
"already contains a Myelin Flow resource"
));
}
// Tests that MyelinateCells() fails on the wrong Component type.
TEST
(
MyelinateCellsTest
,
WrongComponentType
)
{
const
string
kSpec
=
"component { name: 'foo' component_builder { registered_name: "
"'WrongComponent' } }"
;
const
string
master_spec_path
=
tensorflow
::
io
::
JoinPath
(
tensorflow
::
testing
::
TmpDir
(),
"master-spec"
);
TF_ASSERT_OK
(
tensorflow
::
WriteStringToFile
(
tensorflow
::
Env
::
Default
(),
master_spec_path
,
kSpec
));
EXPECT_THAT
(
MyelinateCells
(
GetInput
(
kSavedModelDir
),
master_spec_path
,
{
"foo"
},
GetUniqueOutputDir
()),
test
::
IsErrorWithSubstr
(
"No Myelin-based version of Component subclass 'WrongComponent'"
));
}
// Tests that MyelinateCells() succeeds on the pre-trained inputs and reproduces
// expected outputs.
TEST
(
MyelinateCellsTest
,
RegressionTest
)
{
const
string
output_dir
=
GetUniqueOutputDir
();
TF_ASSERT_OK
(
MyelinateCells
(
GetInput
(
kSavedModelDir
),
GetInput
(
kMasterSpecPath
),
GetComponentNames
(),
output_dir
));
ClearResourceFilePatterns
(
tensorflow
::
io
::
JoinPath
(
output_dir
,
"master-spec"
));
CompareOrRewriteTestData
(
output_dir
,
"master-spec"
);
for
(
const
string
&
component_name
:
GetComponentNames
())
{
const
string
flow_basename
=
tensorflow
::
strings
::
StrCat
(
component_name
,
".flow"
);
CompareOrRewriteTestData
(
output_dir
,
flow_basename
);
}
}
}
// namespace
}
// namespace runtime
}
// namespace dragnn
}
// namespace syntaxnet
research/syntaxnet/dragnn/runtime/myelin/sequence_myelin_dynamic_component.cc
0 → 100644
View file @
a4bb31d0
// Copyright 2017 Google Inc. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
#include <stddef.h>
#include <string>
#include <vector>
#include "dragnn/core/compute_session.h"
#include "dragnn/protos/spec.pb.h"
#include "dragnn/protos/trace.pb.h"
#include "dragnn/runtime/component.h"
#include "dragnn/runtime/extensions.h"
#include "dragnn/runtime/math/types.h"
#include "dragnn/runtime/myelin/myelin_dynamic_component_base.h"
#include "dragnn/runtime/network_states.h"
#include "dragnn/runtime/sequence_features.h"
#include "dragnn/runtime/sequence_links.h"
#include "dragnn/runtime/sequence_model.h"
#include "dragnn/runtime/session_state.h"
#include "dragnn/runtime/variable_store.h"
#include "syntaxnet/base.h"
#include "sling/myelin/compute.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status.h"
namespace
syntaxnet
{
namespace
dragnn
{
namespace
runtime
{
namespace
{
// A Myelin-based version of DynamicComponent for sequence-based models.
class
SequenceMyelinDynamicComponent
:
public
MyelinDynamicComponentBase
{
public:
// Implements Component.
tensorflow
::
Status
Initialize
(
const
ComponentSpec
&
component_spec
,
VariableStore
*
variable_store
,
NetworkStateManager
*
network_state_manager
,
ExtensionManager
*
extension_manager
)
override
;
tensorflow
::
Status
Evaluate
(
SessionState
*
session_state
,
ComputeSession
*
compute_session
,
ComponentTrace
*
component_trace
)
const
override
;
protected:
// Implements Component.
bool
Supports
(
const
ComponentSpec
&
component_spec
,
const
string
&
normalized_builder_name
)
const
override
;
bool
PreferredTo
(
const
Component
&
)
const
override
{
return
false
;
}
private:
// Binds the fixed feature IDs for the |target_index|'th element of the
// |features| to the |instance|. Uses locals in the |network_states|.
void
BindInputIds
(
const
SequenceFeatures
&
features
,
int
target_index
,
const
NetworkStates
&
network_states
,
sling
::
myelin
::
Instance
*
instance
)
const
;
// Binds the linked embeddings for the |target_index|'th element in the
// |links| to the |instance|.
void
BindInputLinks
(
const
SequenceLinks
&
links
,
int
target_index
,
sling
::
myelin
::
Instance
*
instance
)
const
;
// Sequence-based model evaluator.
SequenceModel
sequence_model_
;
// Intermediate values used by sequence models.
SharedExtensionHandle
<
SequenceModel
::
EvaluateState
>
evaluate_state_handle_
;
};
bool
SequenceMyelinDynamicComponent
::
Supports
(
const
ComponentSpec
&
component_spec
,
const
string
&
normalized_builder_name
)
const
{
return
normalized_builder_name
==
"SequenceMyelinDynamicComponent"
&&
SequenceModel
::
Supports
(
component_spec
);
}
tensorflow
::
Status
SequenceMyelinDynamicComponent
::
Initialize
(
const
ComponentSpec
&
component_spec
,
VariableStore
*
variable_store
,
NetworkStateManager
*
network_state_manager
,
ExtensionManager
*
extension_manager
)
{
// Initialize the base class first, so its FixedEmbeddingManager and
// LinkedEmbeddingManager can be wrapped in sequence-based versions.
TF_RETURN_IF_ERROR
(
MyelinDynamicComponentBase
::
Initialize
(
component_spec
,
variable_store
,
network_state_manager
,
extension_manager
));
TF_RETURN_IF_ERROR
(
sequence_model_
.
Initialize
(
component_spec
,
kLogitsName
,
&
fixed_embedding_manager
(),
&
linked_embedding_manager
(),
network_state_manager
));
extension_manager
->
GetShared
(
&
evaluate_state_handle_
);
return
tensorflow
::
Status
::
OK
();
}
void
SequenceMyelinDynamicComponent
::
BindInputIds
(
const
SequenceFeatures
&
features
,
int
target_index
,
const
NetworkStates
&
network_states
,
sling
::
myelin
::
Instance
*
instance
)
const
{
for
(
size_t
channel_id
=
0
;
channel_id
<
features
.
num_channels
();
++
channel_id
)
{
const
MutableVector
<
int32
>
id_vector
=
network_states
.
GetLocal
(
fixed_embedding_manager
().
id_handle
(
channel_id
,
0
));
id_vector
[
0
]
=
features
.
GetId
(
channel_id
,
target_index
);
BindInput
(
Vector
<
int32
>
(
id_vector
),
input_ids
()[
channel_id
].
id
,
instance
);
}
}
void
SequenceMyelinDynamicComponent
::
BindInputLinks
(
const
SequenceLinks
&
links
,
int
target_index
,
sling
::
myelin
::
Instance
*
instance
)
const
{
Vector
<
float
>
embedding
;
bool
is_out_of_bounds
=
false
;
for
(
size_t
channel_id
=
0
;
channel_id
<
links
.
num_channels
();
++
channel_id
)
{
links
.
Get
(
channel_id
,
target_index
,
&
embedding
,
&
is_out_of_bounds
);
BindInputLink
(
embedding
,
is_out_of_bounds
,
input_links
()[
channel_id
],
instance
);
}
}
tensorflow
::
Status
SequenceMyelinDynamicComponent
::
Evaluate
(
SessionState
*
session_state
,
ComputeSession
*
compute_session
,
ComponentTrace
*
component_trace
)
const
{
NetworkStates
&
network_states
=
session_state
->
network_states
;
SequenceModel
::
EvaluateState
&
state
=
session_state
->
extensions
.
Get
(
evaluate_state_handle_
);
TF_RETURN_IF_ERROR
(
sequence_model_
.
Preprocess
(
session_state
,
compute_session
,
&
state
));
// Avoid ComputeSession overhead by directly iterating over the feature IDs.
// Handle forward and reverse iteration via an index and increment.
int
target_index
=
sequence_model_
.
left_to_right
()
?
0
:
state
.
num_steps
-
1
;
const
int
target_increment
=
sequence_model_
.
left_to_right
()
?
1
:
-
1
;
sling
::
myelin
::
Instance
&
instance
=
GetInstance
(
session_state
);
for
(
size_t
step_index
=
0
;
step_index
<
state
.
num_steps
;
++
step_index
,
target_index
+=
target_increment
)
{
// Bind inputs and outputs into the |instance|.
BindInputIds
(
state
.
features
,
target_index
,
network_states
,
&
instance
);
BindInputLinks
(
state
.
links
,
target_index
,
&
instance
);
BindInputRecurrences
(
step_index
,
network_states
,
&
instance
);
BindOutputLayers
(
step_index
,
network_states
,
&
instance
);
// Invoke the cell in the |instance|.
instance
.
Compute
();
MaybeTrace
(
step_index
,
&
instance
,
component_trace
);
}
return
sequence_model_
.
Predict
(
network_states
,
&
state
);
}
DRAGNN_RUNTIME_REGISTER_COMPONENT
(
SequenceMyelinDynamicComponent
);
}
// namespace
}
// namespace runtime
}
// namespace dragnn
}
// namespace syntaxnet
research/syntaxnet/dragnn/runtime/myelin/sequence_myelin_dynamic_component_test.cc
0 → 100644
View file @
a4bb31d0
// Copyright 2017 Google Inc. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
#include <memory>
#include <string>
#include <vector>
#include "dragnn/core/compute_session.h"
#include "dragnn/core/input_batch_cache.h"
#include "dragnn/core/test/generic.h"
#include "dragnn/protos/spec.pb.h"
#include "dragnn/protos/trace.pb.h"
#include "dragnn/runtime/component.h"
#include "dragnn/runtime/extensions.h"
#include "dragnn/runtime/math/types.h"
#include "dragnn/runtime/myelin/myelin_spec_utils.h"
#include "dragnn/runtime/network_states.h"
#include "dragnn/runtime/sequence_backend.h"
#include "dragnn/runtime/sequence_extractor.h"
#include "dragnn/runtime/sequence_linker.h"
#include "dragnn/runtime/sequence_predictor.h"
#include "dragnn/runtime/session_state.h"
#include "dragnn/runtime/test/network_test_base.h"
#include "dragnn/runtime/variable_store.h"
#include "syntaxnet/base.h"
#include <gmock/gmock.h>
#include "sling/file/file.h"
#include "sling/myelin/flow.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/lib/io/path.h"
#include "tensorflow/core/platform/test.h"
namespace
syntaxnet
{
namespace
dragnn
{
namespace
runtime
{
namespace
{
using
::
testing
::
Return
;
constexpr
int
kFlowVersion
=
4
;
constexpr
int
kNumSteps
=
50
;
constexpr
int
kVocabularySize
=
123
;
constexpr
int
kFixedDim
=
6
;
constexpr
int
kLinkedDim
=
4
;
constexpr
int
kLogitsDim
=
kFixedDim
+
kLinkedDim
;
constexpr
char
kLogitsName
[]
=
"logits"
;
constexpr
char
kPreviousComponentName
[]
=
"previous_component"
;
constexpr
char
kPreviousLayerName
[]
=
"previous_layer"
;
constexpr
float
kPreviousLayerValue
=
-
1.0
;
// Builds and writes a simple Flow file with a function named |function_name|
// that gathers the rows of a matrix, concatenates that with a linked embedding,
// and outputs the result as the classification logits. Each row is filled with
// its index, so we can infer which indices were gathered.
string
WriteFlowFile
(
const
string
&
function_name
)
{
sling
::
myelin
::
Flow
flow
;
// A fixed feature ID input.
sling
::
myelin
::
Flow
::
Variable
*
id
=
flow
.
AddVariable
(
"id"
,
sling
::
myelin
::
DT_INT32
,
{
1
});
id
->
ref
=
true
;
id
->
aliases
.
push_back
(
MakeMyelinInputFixedFeatureIdName
(
0
,
0
));
// A linked feature embedding input.
sling
::
myelin
::
Flow
::
Variable
*
link
=
flow
.
AddVariable
(
"link"
,
sling
::
myelin
::
DT_FLOAT
,
{
1
,
kLinkedDim
});
link
->
ref
=
true
;
link
->
aliases
.
push_back
(
MakeMyelinInputLinkedActivationVectorName
(
0
));
// An embedding matrix constant. Each embedding is filled with its index.
sling
::
myelin
::
Flow
::
Variable
*
embeddings
=
flow
.
AddVariable
(
"embeddings"
,
sling
::
myelin
::
DT_FLOAT
,
{
kVocabularySize
,
kFixedDim
});
std
::
vector
<
float
>
data
(
kVocabularySize
*
kLogitsDim
);
for
(
int
row
=
0
;
row
<
kVocabularySize
;
++
row
)
{
for
(
int
column
=
0
;
column
<
kFixedDim
;
++
column
)
{
data
[
row
*
kFixedDim
+
column
]
=
row
;
}
}
embeddings
->
SetData
(
data
.
data
(),
data
.
size
()
*
sizeof
(
float
));
// The retrieved embedding row.
sling
::
myelin
::
Flow
::
Variable
*
row
=
flow
.
AddVariable
(
"row"
,
sling
::
myelin
::
DT_FLOAT
,
{
1
,
kFixedDim
});
// A concatenation axis constant.
sling
::
myelin
::
Flow
::
Variable
*
axis
=
flow
.
AddVariable
(
"axis"
,
sling
::
myelin
::
DT_INT32
,
{
1
});
const
int32
axis_value
=
1
;
axis
->
SetData
(
&
axis_value
,
sizeof
(
int32
));
// The classification logits output.
sling
::
myelin
::
Flow
::
Variable
*
logits
=
flow
.
AddVariable
(
kLogitsName
,
sling
::
myelin
::
DT_FLOAT
,
{
1
,
kLogitsDim
});
logits
->
ref
=
true
;
logits
->
aliases
.
push_back
(
MakeMyelinOutputLayerName
(
kLogitsName
));
// Function that contains the ops and variables.
sling
::
myelin
::
Flow
::
Function
*
function
=
flow
.
AddFunction
(
function_name
);
// A Gather op that looks up the |id| in the |embeddings|, and returns the
// result in the |row|.
flow
.
AddOperation
(
function
,
"gather"
,
"Gather"
,
{
embeddings
,
id
},
{
row
});
// A Concat op that concatenates the |row| and |link| along the |axis|,
// placing the result in the |logits| output.
flow
.
AddOperation
(
function
,
"concat"
,
"ConcatV2"
,
{
row
,
link
,
axis
},
{
logits
});
const
string
flow_path
=
tensorflow
::
io
::
JoinPath
(
tensorflow
::
testing
::
TmpDir
(),
"foo.flow"
);
sling
::
File
::
Init
();
flow
.
Save
(
flow_path
,
kFlowVersion
);
return
flow_path
;
}
// Sequence extractor that extracts [0, 2, 4, ...].
class
EvenNumbers
:
public
SequenceExtractor
{
public:
// Implements SequenceExtractor.
bool
Supports
(
const
FixedFeatureChannel
&
,
const
ComponentSpec
&
)
const
override
{
return
true
;
}
tensorflow
::
Status
Initialize
(
const
FixedFeatureChannel
&
,
const
ComponentSpec
&
)
override
{
return
tensorflow
::
Status
::
OK
();
}
tensorflow
::
Status
GetIds
(
InputBatchCache
*
,
std
::
vector
<
int32
>
*
ids
)
const
override
{
ids
->
clear
();
for
(
int
i
=
0
;
i
<
num_steps_
;
++
i
)
ids
->
push_back
(
2
*
i
);
return
tensorflow
::
Status
::
OK
();
}
// Sets the number of steps to emit.
static
void
SetNumSteps
(
int
num_steps
)
{
num_steps_
=
num_steps
;
}
private:
// The number of steps to produce.
static
int
num_steps_
;
};
int
EvenNumbers
::
num_steps_
=
kNumSteps
;
DRAGNN_RUNTIME_REGISTER_SEQUENCE_EXTRACTOR
(
EvenNumbers
);
// Component that supports a particular component name and is not preferred.
// Used to exercise PreferredTo().
class
NotPreferred
:
public
Component
{
public:
// Implements Component.
tensorflow
::
Status
Initialize
(
const
ComponentSpec
&
,
VariableStore
*
,
NetworkStateManager
*
,
ExtensionManager
*
)
override
{
return
tensorflow
::
Status
::
OK
();
}
tensorflow
::
Status
Evaluate
(
SessionState
*
,
ComputeSession
*
,
ComponentTrace
*
)
const
override
{
return
tensorflow
::
Status
::
OK
();
}
bool
Supports
(
const
ComponentSpec
&
spec
,
const
string
&
)
const
override
{
return
spec
.
name
()
==
"InSupportsConflictTest"
;
}
bool
PreferredTo
(
const
Component
&
)
const
override
{
return
false
;
}
};
DRAGNN_RUNTIME_REGISTER_COMPONENT
(
NotPreferred
);
// Trivial linker that links everything to step 0.
class
LinkToZero
:
public
SequenceLinker
{
public:
// Implements SequenceLinker.
bool
Supports
(
const
LinkedFeatureChannel
&
,
const
ComponentSpec
&
)
const
override
{
return
true
;
}
tensorflow
::
Status
Initialize
(
const
LinkedFeatureChannel
&
,
const
ComponentSpec
&
)
override
{
return
tensorflow
::
Status
::
OK
();
}
tensorflow
::
Status
GetLinks
(
size_t
,
InputBatchCache
*
,
std
::
vector
<
int32
>
*
links
)
const
override
{
links
->
assign
(
num_steps_
,
0
);
return
tensorflow
::
Status
::
OK
();
}
// Sets the number of steps to emit.
static
void
SetNumSteps
(
int
num_steps
)
{
num_steps_
=
num_steps
;
}
private:
// The number of steps to produce.
static
int
num_steps_
;
};
int
LinkToZero
::
num_steps_
=
kNumSteps
;
DRAGNN_RUNTIME_REGISTER_SEQUENCE_LINKER
(
LinkToZero
);
// Trivial predictor that captures the prediction logits.
class
CaptureLogits
:
public
SequencePredictor
{
public:
// Implements SequenceLinker.
bool
Supports
(
const
ComponentSpec
&
)
const
override
{
return
true
;
}
tensorflow
::
Status
Initialize
(
const
ComponentSpec
&
)
override
{
return
tensorflow
::
Status
::
OK
();
}
tensorflow
::
Status
Predict
(
Matrix
<
float
>
logits
,
InputBatchCache
*
)
const
override
{
GetLogits
()
=
logits
;
return
tensorflow
::
Status
::
OK
();
}
// Returns the captured logits.
static
Matrix
<
float
>
&
GetLogits
()
{
static
auto
*
logits
=
new
Matrix
<
float
>
();
return
*
logits
;
}
};
DRAGNN_RUNTIME_REGISTER_SEQUENCE_PREDICTOR
(
CaptureLogits
);
class
SequenceMyelinDynamicComponentTest
:
public
NetworkTestBase
{
protected:
// Adds default call expectations. Since these are added first, they can be
// overridden by call expectations in individual tests.
SequenceMyelinDynamicComponentTest
()
{
EXPECT_CALL
(
compute_session_
,
GetInputBatchCache
())
.
WillRepeatedly
(
Return
(
&
input_
));
EXPECT_CALL
(
compute_session_
,
GetReadiedComponent
(
kTestComponentName
))
.
WillRepeatedly
(
Return
(
&
backend_
));
TF_CHECK_OK
(
Component
::
CreateOrError
(
"SequenceMyelinDynamicComponent"
,
&
component_
));
// Some tests overwrite these; ensure that they are restored to the normal
// values at the start of each test.
EvenNumbers
::
SetNumSteps
(
kNumSteps
);
LinkToZero
::
SetNumSteps
(
kNumSteps
);
CaptureLogits
::
GetLogits
()
=
Matrix
<
float
>
();
}
// Build and write the flow file once.
static
void
SetUpTestCase
()
{
flow_path_
=
new
string
(
WriteFlowFile
(
kTestComponentName
));
}
// Cleans up the flow file path.
static
void
TearDownTestCase
()
{
delete
flow_path_
;
flow_path_
=
nullptr
;
}
// Creates a component, initializes it based on the |component_spec|, and
// evaluates it. On error, returns non-OK.
tensorflow
::
Status
Run
(
ComponentSpec
component_spec
)
{
component_spec
.
set_name
(
kTestComponentName
);
TF_RETURN_IF_ERROR
(
AddMyelinFlowResource
(
*
flow_path_
,
&
component_spec
));
AddComponent
(
kPreviousComponentName
);
AddLayer
(
kPreviousLayerName
,
kLinkedDim
);
AddComponent
(
kTestComponentName
);
TF_RETURN_IF_ERROR
(
component_
->
Initialize
(
component_spec
,
&
variable_store_
,
&
network_state_manager_
,
&
extension_manager_
));
network_states_
.
Reset
(
&
network_state_manager_
);
session_state_
.
extensions
.
Reset
(
&
extension_manager_
);
StartComponent
(
kNumSteps
);
FillLayer
(
kPreviousComponentName
,
kPreviousLayerName
,
kPreviousLayerValue
);
StartComponent
(
0
);
TF_RETURN_IF_ERROR
(
component_
->
Evaluate
(
&
session_state_
,
&
compute_session_
,
nullptr
));
return
tensorflow
::
Status
::
OK
();
}
// Returns the sequence size passed to the |backend_|.
int
GetBackendSequenceSize
()
{
// The sequence size is not directly exposed, but can be inferred using one
// of the reverse step translators.
return
backend_
.
GetStepLookupFunction
(
"reverse-token"
)(
0
,
0
,
0
)
+
1
;
}
// Path to a simple Myelin Flow file.
static
const
string
*
flow_path_
;
// Component used in the test.
std
::
unique_ptr
<
Component
>
component_
;
// Input batch injected into Evaluate() by default.
InputBatchCache
input_
;
// Backend injected into Evaluate().
SequenceBackend
backend_
;
};
const
string
*
SequenceMyelinDynamicComponentTest
::
flow_path_
=
nullptr
;
// Returns a ComponentSpec that is supported.
ComponentSpec
MakeSupportedSpec
()
{
ComponentSpec
component_spec
;
component_spec
.
set_num_actions
(
kLogitsDim
);
component_spec
.
mutable_component_builder
()
->
set_registered_name
(
"SequenceMyelinDynamicComponent"
);
component_spec
.
mutable_component_builder
()
->
mutable_parameters
()
->
insert
(
{
"sequence_extractors"
,
"EvenNumbers"
});
component_spec
.
mutable_component_builder
()
->
mutable_parameters
()
->
insert
(
{
"sequence_linkers"
,
"LinkToZero"
});
component_spec
.
mutable_component_builder
()
->
mutable_parameters
()
->
insert
(
{
"sequence_predictor"
,
"CaptureLogits"
});
component_spec
.
mutable_backend
()
->
set_registered_name
(
"SequenceBackend"
);
FixedFeatureChannel
*
fixed_feature
=
component_spec
.
add_fixed_feature
();
fixed_feature
->
set_size
(
1
);
fixed_feature
->
set_embedding_dim
(
-
1
);
LinkedFeatureChannel
*
linked_feature
=
component_spec
.
add_linked_feature
();
linked_feature
->
set_source_component
(
kPreviousComponentName
);
linked_feature
->
set_source_layer
(
kPreviousLayerName
);
linked_feature
->
set_size
(
1
);
linked_feature
->
set_embedding_dim
(
-
1
);
return
component_spec
;
}
// Tests that the component supports a supported spec.
TEST_F
(
SequenceMyelinDynamicComponentTest
,
Supported
)
{
string
component_type
;
const
ComponentSpec
component_spec
=
MakeSupportedSpec
();
TF_ASSERT_OK
(
Component
::
Select
(
component_spec
,
&
component_type
));
}
// Tests that the component does not support a spec with the wrong component
// builder.
TEST_F
(
SequenceMyelinDynamicComponentTest
,
UnsupportedComponentBuilder
)
{
string
component_type
;
ComponentSpec
component_spec
=
MakeSupportedSpec
();
component_spec
.
mutable_component_builder
()
->
set_registered_name
(
"bad"
);
EXPECT_THAT
(
Component
::
Select
(
component_spec
,
&
component_type
),
test
::
IsErrorWithSubstr
(
"Could not find a best"
));
}
// Tests that the component
TEST_F
(
SequenceMyelinDynamicComponentTest
,
SupportsConflict
)
{
string
component_type
;
ComponentSpec
component_spec
=
MakeSupportedSpec
();
component_spec
.
set_name
(
"InSupportsConflictTest"
);
// see NotPreferred
EXPECT_THAT
(
Component
::
Select
(
component_spec
,
&
component_type
),
test
::
IsErrorWithSubstr
(
"both think they should be dis-preferred"
));
}
// Asserts that the vector starts with |kFixedDim| copies of |value| and ends
// with |kLinkedDim| copies of |kPreviousLayerValue|.
void
AssertOutputRow
(
Vector
<
float
>
row
,
float
value
)
{
ASSERT_EQ
(
row
.
size
(),
kLogitsDim
);
for
(
int
i
=
0
;
i
<
row
.
size
();
++
i
)
{
if
(
i
<
kFixedDim
)
{
ASSERT_EQ
(
row
[
i
],
value
);
}
else
{
ASSERT_EQ
(
row
[
i
],
kPreviousLayerValue
);
}
}
}
// Tests that the component extracts a left-to-right sequence by default.
TEST_F
(
SequenceMyelinDynamicComponentTest
,
LeftToRightByDefault
)
{
TF_ASSERT_OK
(
Run
(
MakeSupportedSpec
()));
EXPECT_EQ
(
GetBackendSequenceSize
(),
kNumSteps
);
const
Matrix
<
float
>
logits
=
CaptureLogits
::
GetLogits
();
ASSERT_EQ
(
logits
.
num_rows
(),
kNumSteps
);
ASSERT_EQ
(
logits
.
num_columns
(),
kLogitsDim
);
for
(
int
i
=
0
;
i
<
kNumSteps
;
++
i
)
{
AssertOutputRow
(
logits
.
row
(
i
),
2.0
*
i
);
}
}
// Tests that the component can be explicitly configured for a left-to-right
// sequence.
TEST_F
(
SequenceMyelinDynamicComponentTest
,
LeftToRightExplicitly
)
{
ComponentSpec
component_spec
=
MakeSupportedSpec
();
(
*
component_spec
.
mutable_transition_system
()
->
mutable_parameters
())[
"left_to_right"
]
=
"true"
;
TF_ASSERT_OK
(
Run
(
component_spec
));
EXPECT_EQ
(
GetBackendSequenceSize
(),
kNumSteps
);
const
Matrix
<
float
>
logits
=
CaptureLogits
::
GetLogits
();
ASSERT_EQ
(
logits
.
num_rows
(),
kNumSteps
);
ASSERT_EQ
(
logits
.
num_columns
(),
kLogitsDim
);
for
(
int
i
=
0
;
i
<
kNumSteps
;
++
i
)
{
AssertOutputRow
(
logits
.
row
(
i
),
2.0
*
i
);
}
}
// Tests that the component can be explicitly configured for a right-to-left
// sequence.
TEST_F
(
SequenceMyelinDynamicComponentTest
,
RightToLeft
)
{
ComponentSpec
component_spec
=
MakeSupportedSpec
();
(
*
component_spec
.
mutable_transition_system
()
->
mutable_parameters
())[
"left_to_right"
]
=
"false"
;
TF_ASSERT_OK
(
Run
(
component_spec
));
EXPECT_EQ
(
GetBackendSequenceSize
(),
kNumSteps
);
const
Matrix
<
float
>
logits
=
CaptureLogits
::
GetLogits
();
ASSERT_EQ
(
logits
.
num_rows
(),
kNumSteps
);
ASSERT_EQ
(
logits
.
num_columns
(),
kLogitsDim
);
for
(
int
i
=
0
;
i
<
kNumSteps
;
++
i
)
{
const
int
reversed
=
kNumSteps
-
i
-
1
;
AssertOutputRow
(
logits
.
row
(
i
),
2.0
*
reversed
);
}
}
// Tests that the component can handle an empty sequence.
TEST_F
(
SequenceMyelinDynamicComponentTest
,
EmptySequence
)
{
EvenNumbers
::
SetNumSteps
(
0
);
LinkToZero
::
SetNumSteps
(
0
);
TF_ASSERT_OK
(
Run
(
MakeSupportedSpec
()));
EXPECT_EQ
(
GetBackendSequenceSize
(),
0
);
const
Matrix
<
float
>
logits
=
CaptureLogits
::
GetLogits
();
ASSERT_EQ
(
logits
.
num_rows
(),
0
);
}
}
// namespace
}
// namespace runtime
}
// namespace dragnn
}
// namespace syntaxnet
research/syntaxnet/dragnn/
python/perf_
test
_
data/master-spec
→
research/syntaxnet/dragnn/
runtime/myelin/
testdata/
myelination_output/
master-spec
View file @
a4bb31d0
component {
name: "
convnet
"
name: "
rnn
"
transition_system {
registered_name: "shift-only"
parameters {
key: "left_to_right"
value: "false"
}
parameters {
key: "parser_skip_deterministic"
value: "false"
}
}
resource {
name: "
lexifuse-repository
"
name: "
words-embedding-input
"
part {
file_pattern: "/cns/lg-d/home/chrisalberti/e/conv/lexifuse.lexifuse-repository/repository"
file_format: "repository"
record_format: "entity"
file_format: "tf-records"
record_format: "syntaxnet.TokenEmbedding"
}
}
resource {
name: "
brain-parser-model
"
name: "
words-vocab-input
"
part {
file_pattern: "/cns/lg-d/home/chrisalberti/e/conv/dragnn-parser.convnet.model-init/brain-parser-model"
file_format: "model"
file_format: "text"
record_format: ""
}
}
resource {
name: "
transition-system-data
"
name: "
char-ngram-map
"
part {
file_pattern: "/cns/lg-d/home/chrisalberti/e/conv/dragnn-parser.convnet.model-init/transition-system-data"
file_format: "model"
file_format: "text"
record_format: ""
}
}
resource {
name: "word
s-embedding-input
"
name: "word
-map
"
part {
file_pattern: "/readahead/512M/cns/lg-d/home/saft/corpora/word-embeddings/en/word2vec/1billion/word2vec-embedding-bi-true-32.sst"
file_format: "sstable"
record_format: "dist_belief.TokenEmbedding"
file_format: "text"
record_format: ""
}
}
resource {
name: "
words-vocab-input
"
name: "
label-map
"
part {
file_pattern: "/cns/lg-d/home/chrisalberti/e/conv/dragnn-parser.convnet.model-init/vocab"
file_format: "text"
record_format: ""
}
}
resource {
name: "
component-builder-module
"
name: "
myelin-flow
"
part {
file_pattern: "/cns/lg-d/home/chrisalberti/e/conv/dragnn-parser.convnet.component-builder-module/module-spec"
file_format: "pbtxt"
record_format: ""
file_format: "model"
record_format: "sling.myelin.Flow"
}
}
fixed_feature {
name: "char_ngram"
fml: "input.token.lexifuse-char-ngram"
embedding_dim: 16
vocabulary_size: 16500
size: 1
predicate_map: "hashed"
name: "char_ngrams"
fml: "input.token { offset(-1).char-ngram(min-length=1,max-length=3,mark-boundaries=true) offset(0).char-ngram(min-length=1,max-length=3,mark-boundaries=true) offset(1).char-ngram(min-length=1,max-length=3,mark-boundaries=true) }"
embedding_dim: -1
vocabulary_size: 25788
size: 3
}
fixed_feature {
name: "words"
fml: "input.
word
"
embedding_dim:
32
vocabulary_size:
39395
fml: "input.
token.word(min-freq=2)
"
embedding_dim:
-1
vocabulary_size:
23769
size: 1
predicate_map: "hashed"
}
network_unit {
registered_name: "IdentityNetwork"
registered_name: "LSTMNetwork"
parameters {
key: "hidden_layer_sizes"
value: "128"
}
parameters {
key: "omit_logits"
value: "true"
}
}
backend {
registered_name: "
Parser
Component"
registered_name: "
SyntaxNet
Component"
}
num_actions: 1
attention_component: ""
component_builder {
registered_name: "components.common.dragnn.python.conv_component.ConvComponentBuilder"
parameters {
key: "depths"
value: "48,128"
}
parameters {
key: "output_dims"
value: "45"
}
parameters {
key: "widths"
value: "7"
registered_name: "MyelinDynamicComponent"
}
}
training_beam_size: 1
inference_beam_size: 1
}
component {
name: "tagger"
...
...
@@ -109,63 +99,62 @@ component {
resource {
name: "tag-map"
part {
file_pattern: "/cns/lg-d/home/chrisalberti/e/conv/lexifuse.lexicon/tag-map"
file_format: "text"
record_format: ""
}
}
resource {
name: "
lexifuse-reposit
ory"
name: "
tag-to-categ
ory"
part {
file_pattern: "/cns/lg-d/home/chrisalberti/e/conv/lexifuse.lexifuse-repository/repository"
file_format: "repository"
record_format: "entity"
file_format: "text"
record_format: ""
}
}
resource {
name: "
brain-parser-model
"
name: "
label-map
"
part {
file_pattern: "/cns/lg-d/home/chrisalberti/e/conv/dragnn-parser.tagger.model-init/brain-parser-model"
file_format: "model"
file_format: "text"
record_format: ""
}
}
resource {
name: "
transition-system-data
"
name: "
myelin-flow
"
part {
file_pattern: "/cns/lg-d/home/chrisalberti/e/conv/dragnn-parser.tagger.model-init/transition-system-data"
file_format: "model"
record_format: ""
}
record_format: "sling.myelin.Flow"
}
resource {
name: "component-builder-module"
part {
file_pattern: "/cns/lg-d/home/chrisalberti/e/conv/dragnn-parser.tagger.component-builder-module/module-spec"
file_format: "pbtxt"
record_format: ""
}
linked_feature {
name: "recurrence"
fml: "bias(0)"
embedding_dim: -1
size: 1
source_component: "tagger"
source_translator: "history"
source_layer: "layer_0"
}
linked_feature {
name: "
convnet
"
name: "
rnn
"
fml: "input.focus"
embedding_dim: -1
size: 1
source_component: "
convnet
"
source_translator: "
identity
"
source_layer: "
conv0_logits
"
source_component: "
rnn
"
source_translator: "
reverse-token
"
source_layer: "
layer_0
"
}
network_unit {
registered_name: "IdentityNetwork"
registered_name: "FeedForwardNetwork"
parameters {
key: "hidden_layer_sizes"
value: "64,64"
}
}
backend {
registered_name: "
Parser
Component"
registered_name: "
SyntaxNet
Component"
}
num_actions: 45
attention_component: ""
component_builder {
registered_name: "
bulk_component.BulkAnnotator
Component
Builder
"
registered_name: "
MyelinDynamic
Component"
}
training_beam_size: 1
inference_beam_size: 1
}
research/syntaxnet/dragnn/runtime/myelin/testdata/myelination_output/rnn.flow
0 → 100644
View file @
a4bb31d0
File added
research/syntaxnet/dragnn/runtime/myelin/testdata/myelination_output/tagger.flow
0 → 100644
View file @
a4bb31d0
File added
research/syntaxnet/dragnn/runtime/network_states.cc
0 → 100644
View file @
a4bb31d0
// Copyright 2017 Google Inc. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
#include "dragnn/runtime/network_states.h"
#include "tensorflow/core/lib/core/errors.h"
namespace
syntaxnet
{
namespace
dragnn
{
namespace
runtime
{
namespace
{
// Returns the first value in |container| whose ".name" field is |name|, or null
// if not found.
template
<
class
Container
>
const
typename
Container
::
value_type
*
Find
(
const
Container
&
container
,
const
string
&
name
)
{
for
(
auto
&
value
:
container
)
{
if
(
value
.
name
==
name
)
return
&
value
;
}
return
nullptr
;
}
}
// namespace
tensorflow
::
Status
NetworkStateManager
::
AddComponent
(
const
string
&
name
)
{
if
(
Find
(
components_
,
name
)
!=
nullptr
)
{
return
tensorflow
::
errors
::
FailedPrecondition
(
"Component '"
,
name
,
"' already exists"
);
}
// Success; make modifications.
components_
.
emplace_back
(
name
);
return
tensorflow
::
Status
::
OK
();
}
tensorflow
::
Status
NetworkStateManager
::
AddLayerImpl
(
const
string
&
name
,
std
::
type_index
type
,
bool
is_pairwise
,
size_t
bytes
,
size_t
*
component_index
,
OperandHandle
*
operand_handle
)
{
if
(
components_
.
empty
())
{
return
tensorflow
::
errors
::
FailedPrecondition
(
"No current component"
);
}
ComponentConfig
&
component
=
components_
.
back
();
if
(
Find
(
component
.
layers
,
name
)
!=
nullptr
)
{
return
tensorflow
::
errors
::
FailedPrecondition
(
"Layer '"
,
name
,
"' already exists in component '"
,
component
.
name
,
"'"
);
}
if
(
component
.
aliases
.
find
(
name
)
!=
component
.
aliases
.
end
())
{
return
tensorflow
::
errors
::
FailedPrecondition
(
"Layer '"
,
name
,
"' conflicts with an existing alias in component '"
,
component
.
name
,
"'"
);
}
// Success; make modifications.
const
OperandType
operand_type
=
is_pairwise
?
OperandType
::
kPairwise
:
OperandType
::
kStepwise
;
*
component_index
=
components_
.
size
()
-
1
;
*
operand_handle
=
component
.
manager
.
Add
({
operand_type
,
bytes
});
component
.
layers
.
emplace_back
(
name
,
type
,
*
operand_handle
);
return
tensorflow
::
Status
::
OK
();
}
tensorflow
::
Status
NetworkStateManager
::
AddLayerAlias
(
const
string
&
alias
,
const
string
&
name
)
{
if
(
components_
.
empty
())
{
return
tensorflow
::
errors
::
FailedPrecondition
(
"No current component"
);
}
ComponentConfig
&
component
=
components_
.
back
();
if
(
Find
(
component
.
layers
,
name
)
==
nullptr
)
{
return
tensorflow
::
errors
::
FailedPrecondition
(
"Target layer '"
,
name
,
"' of alias '"
,
alias
,
"' does not exist in component '"
,
component
.
name
,
"'"
);
}
if
(
Find
(
component
.
layers
,
alias
)
!=
nullptr
)
{
return
tensorflow
::
errors
::
FailedPrecondition
(
"Alias '"
,
alias
,
"' conflicts with an existing layer in component '"
,
component
.
name
,
"'"
);
}
if
(
component
.
aliases
.
find
(
alias
)
!=
component
.
aliases
.
end
())
{
return
tensorflow
::
errors
::
FailedPrecondition
(
"Alias '"
,
alias
,
"' already exists in component '"
,
component
.
name
,
"'"
);
}
// Success; make modifications.
component
.
aliases
[
alias
]
=
name
;
return
tensorflow
::
Status
::
OK
();
}
tensorflow
::
Status
NetworkStateManager
::
AddLocalImpl
(
const
OperandSpec
&
spec
,
OperandHandle
*
handle
)
{
if
(
components_
.
empty
())
{
return
tensorflow
::
errors
::
FailedPrecondition
(
"No current component"
);
}
ComponentConfig
&
component
=
components_
.
back
();
// Success; make modifications.
*
handle
=
component
.
manager
.
Add
(
spec
);
return
tensorflow
::
Status
::
OK
();
}
tensorflow
::
Status
NetworkStateManager
::
LookupLayerImpl
(
const
string
&
component_name
,
const
string
&
layer_name_or_alias
,
std
::
type_index
type
,
bool
is_pairwise
,
size_t
*
bytes
,
size_t
*
component_index
,
OperandHandle
*
operand_handle
)
const
{
const
ComponentConfig
*
component
=
Find
(
components_
,
component_name
);
if
(
component
==
nullptr
)
{
return
tensorflow
::
errors
::
FailedPrecondition
(
"Unknown component '"
,
component_name
,
"'"
);
}
// If necessary, resolve a layer alias into a layer name. Note that aliases
// are non-transitive, since AddLayerAlias() requires that the target of the
// alias is a layer.
const
auto
it
=
component
->
aliases
.
find
(
layer_name_or_alias
);
const
string
&
layer_name
=
it
!=
component
->
aliases
.
end
()
?
it
->
second
:
layer_name_or_alias
;
const
LayerConfig
*
layer
=
Find
(
component
->
layers
,
layer_name
);
if
(
layer
==
nullptr
)
{
return
tensorflow
::
errors
::
FailedPrecondition
(
"Unknown layer '"
,
layer_name
,
"' in component '"
,
component_name
,
"'"
);
}
if
(
layer
->
type
!=
type
)
{
return
tensorflow
::
errors
::
InvalidArgument
(
"Layer '"
,
layer_name
,
"' in component '"
,
component_name
,
"' does not match its expected type"
);
}
const
OperandType
required_type
=
is_pairwise
?
OperandType
::
kPairwise
:
OperandType
::
kStepwise
;
const
OperandSpec
&
operand_spec
=
component
->
manager
.
spec
(
layer
->
handle
);
if
(
operand_spec
.
type
!=
required_type
)
{
return
tensorflow
::
errors
::
InvalidArgument
(
"Layer '"
,
layer_name
,
"' in component '"
,
component_name
,
"' does not match its expected OperandType"
);
}
// Success; make modifications.
*
bytes
=
operand_spec
.
size
;
*
component_index
=
component
-
components_
.
data
();
*
operand_handle
=
layer
->
handle
;
return
tensorflow
::
Status
::
OK
();
}
void
NetworkStates
::
Reset
(
const
NetworkStateManager
*
manager
)
{
manager_
=
manager
;
num_active_components_
=
0
;
// Never shrink the |component_operands_|, to avoid deallocating (and then
// eventually reallocating) operand arrays.
if
(
manager_
->
components_
.
size
()
>
component_operands_
.
size
())
{
component_operands_
.
resize
(
manager_
->
components_
.
size
());
}
}
tensorflow
::
Status
NetworkStates
::
StartNextComponent
(
size_t
pre_allocate_num_steps
)
{
if
(
manager_
==
nullptr
)
{
return
tensorflow
::
errors
::
FailedPrecondition
(
"No manager"
);
}
if
(
num_active_components_
>=
manager_
->
components_
.
size
())
{
return
tensorflow
::
errors
::
OutOfRange
(
"No next component"
);
}
// Success; make modifications.
const
OperandManager
*
operand_manager
=
&
manager_
->
components_
[
num_active_components_
].
manager
;
component_operands_
[
num_active_components_
].
Reset
(
operand_manager
,
pre_allocate_num_steps
);
++
num_active_components_
;
return
tensorflow
::
Status
::
OK
();
}
}
// namespace runtime
}
// namespace dragnn
}
// namespace syntaxnet
research/syntaxnet/dragnn/runtime/network_states.h
0 → 100644
View file @
a4bb31d0
// Copyright 2017 Google Inc. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
// Utils for declaring, allocating, and retrieving network states, similar to
// the "NetworkState" class and the "network_states" argument to the build_*()
// methods of ComponentBuilderBase; see component.py.
//
// In brief, a DRAGNN network consists of a sequence of named components, each
// of which produces a set of named output layers. Each component can access
// its own layers as well as those of preceding components. Components can also
// access "local operands", which are like layers but private to that particular
// component. Local operands can be useful for, e.g., caching an intermediate
// result in a complex computation.
//
// For example, suppose a network has two components: "tagger" and "parser",
// where the parser uses the hidden activations of the tagger. In this case,
// the tagger can add a layer called "hidden" at init time and fill that layer
// at processing time. Corespondingly, the parser can look for a layer called
// "hidden" in the "tagger" component at init time, and read the activations at
// processing time. (Note that for convenience, such links should be handled
// using the utils in linked_embeddings.h).
//
// As another example, suppose we are implementing an LSTM and we wish to keep
// the cell state private. In this case, the LSTM component could add a layer
// for exporting the hidden activations and a local matrix for the sequence of
// cell states. A more compact approach is to use two local vectors instead,
// one for even steps and the other for odd steps.
#ifndef DRAGNN_RUNTIME_NETWORK_STATES_H_
#define DRAGNN_RUNTIME_NETWORK_STATES_H_
#include <stddef.h>
#include <stdint.h>
#include <map>
#include <string>
#include <typeindex>
#include <typeinfo>
#include <vector>
#include "dragnn/runtime/alignment.h"
#include "dragnn/runtime/math/types.h"
#include "dragnn/runtime/operands.h"
#include "syntaxnet/base.h"
#include "tensorflow/core/lib/core/status.h"
namespace
syntaxnet
{
namespace
dragnn
{
namespace
runtime
{
// Opaque handles used to access typed layers or local operands.
template
<
class
T
>
class
LayerHandle
;
template
<
class
T
>
class
PairwiseLayerHandle
;
template
<
class
T
>
class
LocalVectorHandle
;
template
<
class
T
>
class
LocalMatrixHandle
;
// A class that manages the state of a DRAGNN network and associates each layer
// and local operand with a handle. Layer and local operand contents can be
// retrieved using these handles; see NetworkStates below.
class
NetworkStateManager
{
public:
// Creates an empty manager.
NetworkStateManager
()
=
default
;
// Adds a component named |name| and makes it the current component. The
// |name| must be unique in the network. Components are sequenced in the
// order they are added. On error, returns non-OK and modifies nothing.
tensorflow
::
Status
AddComponent
(
const
string
&
name
);
// Adds a layer named |name| to the current component and sets |handle| to its
// handle. The |name| must be unique in the current component. The layer is
// realized as a Matrix<T> with one row per step and |dimension| columns. On
// error, returns non-OK and modifies nothing.
template
<
class
T
>
tensorflow
::
Status
AddLayer
(
const
string
&
name
,
size_t
dimension
,
LayerHandle
<
T
>
*
handle
);
// As above, but for pairwise layers.
template
<
class
T
>
tensorflow
::
Status
AddLayer
(
const
string
&
name
,
size_t
dimension
,
PairwiseLayerHandle
<
T
>
*
handle
);
// As above, but for a local Vector<T> or Matrix<T> operand. The operand is
// "local" in the sense that only the caller knows its handle.
template
<
class
T
>
tensorflow
::
Status
AddLocal
(
size_t
dimension
,
LocalVectorHandle
<
T
>
*
handle
);
template
<
class
T
>
tensorflow
::
Status
AddLocal
(
size_t
dimension
,
LocalMatrixHandle
<
T
>
*
handle
);
// Makes |alias| an alias of the layer named |name| in the current component,
// so that lookups of |alias| resolve to |name|. The |name| must already
// exist as a layer, and layer names and aliases must be unique within each
// component. On error, returns non-OK and modifies nothing.
tensorflow
::
Status
AddLayerAlias
(
const
string
&
alias
,
const
string
&
name
);
// Finds the layer that matches |layer_name_or_alias| in the component named
// |component_name|. Sets |dimension| to its dimension and |handle| to its
// handle. On error, returns non-OK and modifies nothing.
template
<
class
T
>
tensorflow
::
Status
LookupLayer
(
const
string
&
component_name
,
const
string
&
layer_name_or_alias
,
size_t
*
dimension
,
LayerHandle
<
T
>
*
handle
)
const
;
// As above, but for pairwise layers.
template
<
class
T
>
tensorflow
::
Status
LookupLayer
(
const
string
&
component_name
,
const
string
&
layer_name_or_alias
,
size_t
*
dimension
,
PairwiseLayerHandle
<
T
>
*
handle
)
const
;
private:
friend
class
NetworkStates
;
// Configuration information for a layer.
struct
LayerConfig
{
// Creates a config for a layer with the |name|, |type| ID, and |handle|.
LayerConfig
(
const
string
&
name
,
std
::
type_index
type
,
OperandHandle
handle
)
:
name
(
name
),
type
(
type
),
handle
(
handle
)
{}
// Name of the layer.
string
name
;
// Type ID of the layer contents.
std
::
type_index
type
;
// Handle of the operand that holds the layer contents.
OperandHandle
handle
;
};
// Configuration information for a component.
struct
ComponentConfig
{
// Creates an empty config for a component with the |name|.
explicit
ComponentConfig
(
const
string
&
name
)
:
name
(
name
)
{}
// Name of the component.
string
name
;
// Manager for the operands used by the component.
OperandManager
manager
;
// Configuration of each layer produced by the component.
std
::
vector
<
LayerConfig
>
layers
;
// Mapping from layer alias to layer name in the component.
std
::
map
<
string
,
string
>
aliases
;
};
// Implements the non-templated part of AddLayer(). Adds a layer with the
// |name|, |type| ID, and size in |bytes|. Sets the |component_index| and
// |operand_handle| according to the containing component and operand. If
// |is_pairwise| is true, then the new layer is pairwise (vs stepwise). On
// error, returns non-OK and modifies nothing.
tensorflow
::
Status
AddLayerImpl
(
const
string
&
name
,
std
::
type_index
type
,
bool
is_pairwise
,
size_t
bytes
,
size_t
*
component_index
,
OperandHandle
*
operand_handle
);
// Implements the non-templated portion of AddLocal*(). Adds a local operand
// with the |spec| and sets |handle| to its handle. On error, returns non-OK
// and modifies nothing.
tensorflow
::
Status
AddLocalImpl
(
const
OperandSpec
&
spec
,
OperandHandle
*
handle
);
// Implements the non-templated portion of LookupLayer(). Finds the layer
// that matches the |component_name| and |layer_name_or_alias|. That layer
// must match the |type| ID. Sets |bytes| to its size, |component_index| to
// the index of its containing component, and |operand_handle| to the handle
// of its underlying operand. If |is_pairwise| is true, then the layer must
// be pairwise (vs stepwise). On error, returns non-OK and modifies nothing.
tensorflow
::
Status
LookupLayerImpl
(
const
string
&
component_name
,
const
string
&
layer_name_or_alias
,
std
::
type_index
type
,
bool
is_pairwise
,
size_t
*
bytes
,
size_t
*
component_index
,
OperandHandle
*
operand_handle
)
const
;
// Ordered list of configurations for the components in the network.
std
::
vector
<
ComponentConfig
>
components_
;
};
// A set of network states. The structure of the network is configured by a
// NetworkStateManager, and layer and local operand contents can be accessed
// using the handles produced by the manager.
//
// Multiple NetworkStates instances can share the same NetworkStateManager. In
// addition, a NetworkStates instance can be reused by repeatedly Reset()-ing
// it, potentially with different NetworkStateManagers. Such reuse can reduce
// allocation overhead.
class
NetworkStates
{
public:
// Creates an uninitialized set of states.
NetworkStates
()
=
default
;
// Resets this to an empty set configured by the |manager|. The |manager|
// must live until this is destroyed or Reset(), and should not be modified
// during that time. No current component is set; call StartNextComponent()
// to start the first component.
void
Reset
(
const
NetworkStateManager
*
manager
);
// Starts the next component and makes it the current component. Initially,
// the component has zero steps but more can be added using AddStep(). Uses
// |pre_allocate_num_steps| to pre-allocate storage; see Operands::Reset().
// On error, returns non-OK and modifies nothing.
tensorflow
::
Status
StartNextComponent
(
size_t
pre_allocate_num_steps
);
// Adds one or more steps to the current component. Invalidates all
// previously-returned matrices of the current component.
void
AddStep
()
{
AddSteps
(
1
);
}
void
AddSteps
(
size_t
num_steps
);
// Returns the layer associated with the |handle|.
template
<
class
T
>
MutableMatrix
<
T
>
GetLayer
(
LayerHandle
<
T
>
handle
)
const
;
// Returns the pairwise layer associated with the |handle|.
template
<
class
T
>
MutableMatrix
<
T
>
GetLayer
(
PairwiseLayerHandle
<
T
>
handle
)
const
;
// Returns the local vector or matrix associated with the |handle| in the
// current component.
template
<
class
T
>
MutableVector
<
T
>
GetLocal
(
LocalVectorHandle
<
T
>
handle
)
const
;
template
<
class
T
>
MutableMatrix
<
T
>
GetLocal
(
LocalMatrixHandle
<
T
>
handle
)
const
;
private:
// Manager of this set of network states.
const
NetworkStateManager
*
manager_
=
nullptr
;
// Number of active components in the |component_operands_|.
size_t
num_active_components_
=
0
;
// Ordered list of per-component operands. Only the first
// |num_active_components_| entries are valid.
std
::
vector
<
Operands
>
component_operands_
;
};
// Implementation details below.
// An opaque handle to a typed layer of some component.
template
<
class
T
>
class
LayerHandle
{
public:
static_assert
(
IsAlignable
<
T
>
(),
"T must be alignable"
);
// Creates an invalid handle.
LayerHandle
()
=
default
;
private:
friend
class
NetworkStateManager
;
friend
class
NetworkStates
;
// Index of the containing component in the network state manager.
size_t
component_index_
=
SIZE_MAX
;
// Handle of the operand holding the layer.
OperandHandle
operand_handle_
;
};
// An opaque handle to a typed pairwise layer of some component.
template
<
class
T
>
class
PairwiseLayerHandle
{
public:
static_assert
(
IsAlignable
<
T
>
(),
"T must be alignable"
);
// Creates an invalid handle.
PairwiseLayerHandle
()
=
default
;
private:
friend
class
NetworkStateManager
;
friend
class
NetworkStates
;
// Index of the containing component in the network state manager.
size_t
component_index_
=
SIZE_MAX
;
// Handle of the operand holding the layer.
OperandHandle
operand_handle_
;
};
// An opaque handle to a typed local operand of some component.
template
<
class
T
>
class
LocalVectorHandle
{
public:
static_assert
(
IsAlignable
<
T
>
(),
"T must be alignable"
);
// Creates an invalid handle.
LocalVectorHandle
()
=
default
;
private:
friend
class
NetworkStateManager
;
friend
class
NetworkStates
;
// Handle of the local operand.
OperandHandle
operand_handle_
;
};
// An opaque handle to a typed local operand of some component.
template
<
class
T
>
class
LocalMatrixHandle
{
public:
static_assert
(
IsAlignable
<
T
>
(),
"T must be alignable"
);
// Creates an invalid handle.
LocalMatrixHandle
()
=
default
;
private:
friend
class
NetworkStateManager
;
friend
class
NetworkStates
;
// Handle of the local operand.
OperandHandle
operand_handle_
;
};
template
<
class
T
>
tensorflow
::
Status
NetworkStateManager
::
AddLayer
(
const
string
&
name
,
size_t
dimension
,
LayerHandle
<
T
>
*
handle
)
{
return
AddLayerImpl
(
name
,
std
::
type_index
(
typeid
(
T
)),
/*is_pairwise=*/
false
,
dimension
*
sizeof
(
T
),
&
handle
->
component_index_
,
&
handle
->
operand_handle_
);
}
template
<
class
T
>
tensorflow
::
Status
NetworkStateManager
::
AddLayer
(
const
string
&
name
,
size_t
dimension
,
PairwiseLayerHandle
<
T
>
*
handle
)
{
return
AddLayerImpl
(
name
,
std
::
type_index
(
typeid
(
T
)),
/*is_pairwise=*/
true
,
dimension
*
sizeof
(
T
),
&
handle
->
component_index_
,
&
handle
->
operand_handle_
);
}
template
<
class
T
>
tensorflow
::
Status
NetworkStateManager
::
AddLocal
(
size_t
dimension
,
LocalVectorHandle
<
T
>
*
handle
)
{
return
AddLocalImpl
({
OperandType
::
kSingular
,
dimension
*
sizeof
(
T
)},
&
handle
->
operand_handle_
);
}
template
<
class
T
>
tensorflow
::
Status
NetworkStateManager
::
AddLocal
(
size_t
dimension
,
LocalMatrixHandle
<
T
>
*
handle
)
{
return
AddLocalImpl
({
OperandType
::
kStepwise
,
dimension
*
sizeof
(
T
)},
&
handle
->
operand_handle_
);
}
template
<
class
T
>
tensorflow
::
Status
NetworkStateManager
::
LookupLayer
(
const
string
&
component_name
,
const
string
&
layer_name_or_alias
,
size_t
*
dimension
,
LayerHandle
<
T
>
*
handle
)
const
{
TF_RETURN_IF_ERROR
(
LookupLayerImpl
(
component_name
,
layer_name_or_alias
,
std
::
type_index
(
typeid
(
T
)),
/*is_pairwise=*/
false
,
dimension
,
&
handle
->
component_index_
,
&
handle
->
operand_handle_
));
DCHECK_EQ
(
*
dimension
%
sizeof
(
T
),
0
);
*
dimension
/=
sizeof
(
T
);
// bytes => Ts
return
tensorflow
::
Status
::
OK
();
}
template
<
class
T
>
tensorflow
::
Status
NetworkStateManager
::
LookupLayer
(
const
string
&
component_name
,
const
string
&
layer_name_or_alias
,
size_t
*
dimension
,
PairwiseLayerHandle
<
T
>
*
handle
)
const
{
TF_RETURN_IF_ERROR
(
LookupLayerImpl
(
component_name
,
layer_name_or_alias
,
std
::
type_index
(
typeid
(
T
)),
/*is_pairwise=*/
true
,
dimension
,
&
handle
->
component_index_
,
&
handle
->
operand_handle_
));
DCHECK_EQ
(
*
dimension
%
sizeof
(
T
),
0
);
*
dimension
/=
sizeof
(
T
);
// bytes => Ts
return
tensorflow
::
Status
::
OK
();
}
inline
void
NetworkStates
::
AddSteps
(
size_t
num_steps
)
{
component_operands_
[
num_active_components_
-
1
].
AddSteps
(
num_steps
);
}
template
<
class
T
>
MutableMatrix
<
T
>
NetworkStates
::
GetLayer
(
LayerHandle
<
T
>
handle
)
const
{
return
MutableMatrix
<
T
>
(
component_operands_
[
handle
.
component_index_
].
GetStepwise
(
handle
.
operand_handle_
));
}
template
<
class
T
>
MutableMatrix
<
T
>
NetworkStates
::
GetLayer
(
PairwiseLayerHandle
<
T
>
handle
)
const
{
return
MutableMatrix
<
T
>
(
component_operands_
[
handle
.
component_index_
].
GetPairwise
(
handle
.
operand_handle_
));
}
template
<
class
T
>
MutableVector
<
T
>
NetworkStates
::
GetLocal
(
LocalVectorHandle
<
T
>
handle
)
const
{
return
MutableVector
<
T
>
(
component_operands_
[
num_active_components_
-
1
].
GetSingular
(
handle
.
operand_handle_
));
}
template
<
class
T
>
MutableMatrix
<
T
>
NetworkStates
::
GetLocal
(
LocalMatrixHandle
<
T
>
handle
)
const
{
return
MutableMatrix
<
T
>
(
component_operands_
[
num_active_components_
-
1
].
GetStepwise
(
handle
.
operand_handle_
));
}
}
// namespace runtime
}
// namespace dragnn
}
// namespace syntaxnet
#endif // DRAGNN_RUNTIME_NETWORK_STATES_H_
research/syntaxnet/dragnn/runtime/network_states_test.cc
0 → 100644
View file @
a4bb31d0
// Copyright 2017 Google Inc. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
#include "dragnn/runtime/network_states.h"
#include <stddef.h>
#include <string.h>
#include <string>
#include <vector>
#include "dragnn/core/test/generic.h"
#include "dragnn/runtime/alignment.h"
#include "dragnn/runtime/math/types.h"
#include "syntaxnet/base.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/test.h"
namespace
syntaxnet
{
namespace
dragnn
{
namespace
runtime
{
namespace
{
// Expects that two objects have identical bit representations.
template
<
class
T
>
void
ExpectBitwiseEqual
(
const
T
&
object1
,
const
T
&
object2
)
{
EXPECT_EQ
(
memcmp
(
&
object1
,
&
object2
,
sizeof
(
T
)),
0
);
}
// Expects that the |matrix| has the given dimensions.
template
<
class
T
>
void
ExpectDimensions
(
MutableMatrix
<
T
>
matrix
,
size_t
num_rows
,
size_t
num_columns
)
{
EXPECT_EQ
(
matrix
.
num_rows
(),
num_rows
);
EXPECT_EQ
(
matrix
.
num_columns
(),
num_columns
);
}
// Sets the |vector| to |size| copies of the |value|.
template
<
class
T
>
void
Fill
(
MutableVector
<
T
>
vector
,
size_t
size
,
T
value
)
{
ASSERT_EQ
(
vector
.
size
(),
size
);
for
(
T
&
element
:
vector
)
element
=
value
;
}
// Expects that the |vector| contains |size| copies of the |expected_value|.
template
<
class
T
>
void
ExpectFilled
(
MutableVector
<
T
>
vector
,
size_t
size
,
T
expected_value
)
{
ASSERT_EQ
(
vector
.
size
(),
size
);
for
(
const
T
element
:
vector
)
EXPECT_EQ
(
element
,
expected_value
);
}
// Tests that NetworkStateManager can add a named component.
TEST
(
NetworkStateManagerTest
,
AddComponent
)
{
NetworkStateManager
manager
;
TF_EXPECT_OK
(
manager
.
AddComponent
(
"foo/bar"
));
EXPECT_THAT
(
manager
.
AddComponent
(
"foo/bar"
),
test
::
IsErrorWithSubstr
(
"Component 'foo/bar' already exists"
));
// Empty component name is weird, but OK.
TF_EXPECT_OK
(
manager
.
AddComponent
(
""
));
EXPECT_THAT
(
manager
.
AddComponent
(
""
),
test
::
IsErrorWithSubstr
(
"Component '' already exists"
));
}
// Tests that NetworkStateManager can add a named layer to the current
// component.
TEST
(
NetworkStateManagerTest
,
AddLayer
)
{
NetworkStateManager
manager
;
LayerHandle
<
float
>
unused_layer_handle
;
EXPECT_THAT
(
manager
.
AddLayer
(
"layer"
,
1
,
&
unused_layer_handle
),
test
::
IsErrorWithSubstr
(
"No current component"
));
TF_EXPECT_OK
(
manager
.
AddComponent
(
"component"
));
TF_EXPECT_OK
(
manager
.
AddLayer
(
"layer"
,
2
,
&
unused_layer_handle
));
EXPECT_THAT
(
manager
.
AddLayer
(
"layer"
,
2
,
&
unused_layer_handle
),
test
::
IsErrorWithSubstr
(
"Layer 'layer' already exists in component 'component'"
));
}
// Tests that NetworkStateManager can add a named pairwise layer to the current
// component.
TEST
(
NetworkStateManagerTest
,
AddLayerPairwise
)
{
NetworkStateManager
manager
;
PairwiseLayerHandle
<
float
>
unused_layer_handle
;
EXPECT_THAT
(
manager
.
AddLayer
(
"layer"
,
1
,
&
unused_layer_handle
),
test
::
IsErrorWithSubstr
(
"No current component"
));
TF_EXPECT_OK
(
manager
.
AddComponent
(
"component"
));
TF_EXPECT_OK
(
manager
.
AddLayer
(
"layer"
,
2
,
&
unused_layer_handle
));
EXPECT_THAT
(
manager
.
AddLayer
(
"layer"
,
2
,
&
unused_layer_handle
),
test
::
IsErrorWithSubstr
(
"Layer 'layer' already exists in component 'component'"
));
}
// Tests that NetworkStateManager can add an alias to an existing layer. Also
// tests that layer and alias names are required to be unique.
TEST
(
NetworkStateManagerTest
,
AddLayerAlias
)
{
NetworkStateManager
manager
;
LayerHandle
<
float
>
unused_layer_handle
;
EXPECT_THAT
(
manager
.
AddLayerAlias
(
"alias"
,
"layer"
),
test
::
IsErrorWithSubstr
(
"No current component"
));
TF_EXPECT_OK
(
manager
.
AddComponent
(
"component"
));
EXPECT_THAT
(
manager
.
AddLayerAlias
(
"alias"
,
"layer"
),
test
::
IsErrorWithSubstr
(
"Target layer 'layer' of alias 'alias' does not "
"exist in component 'component'"
));
TF_EXPECT_OK
(
manager
.
AddLayer
(
"layer"
,
2
,
&
unused_layer_handle
));
TF_EXPECT_OK
(
manager
.
AddLayerAlias
(
"alias"
,
"layer"
));
EXPECT_THAT
(
manager
.
AddLayerAlias
(
"alias"
,
"layer"
),
test
::
IsErrorWithSubstr
(
"Alias 'alias' already exists in component 'component'"
));
EXPECT_THAT
(
manager
.
AddLayer
(
"alias"
,
2
,
&
unused_layer_handle
),
test
::
IsErrorWithSubstr
(
"Layer 'alias' conflicts with an existing alias "
"in component 'component'"
));
TF_EXPECT_OK
(
manager
.
AddLayer
(
"layer2"
,
2
,
&
unused_layer_handle
));
EXPECT_THAT
(
manager
.
AddLayerAlias
(
"layer2"
,
"layer"
),
test
::
IsErrorWithSubstr
(
"Alias 'layer2' conflicts with an existing layer "
"in component 'component'"
));
}
// Tests that NetworkStateManager can add a local matrix or vector to the
// current component.
TEST
(
NetworkStateManagerTest
,
AddLocal
)
{
NetworkStateManager
manager
;
LocalVectorHandle
<
float
>
unused_local_vector_handle
;
LocalMatrixHandle
<
float
>
unused_local_matrix_handle
;
EXPECT_THAT
(
manager
.
AddLocal
(
11
,
&
unused_local_matrix_handle
),
test
::
IsErrorWithSubstr
(
"No current component"
));
TF_EXPECT_OK
(
manager
.
AddComponent
(
"component"
));
TF_EXPECT_OK
(
manager
.
AddLocal
(
22
,
&
unused_local_matrix_handle
));
TF_EXPECT_OK
(
manager
.
AddLocal
(
33
,
&
unused_local_vector_handle
));
}
// Tests that NetworkStateManager can look up existing layers or aliases, and
// fails on invalid layer or component names and for mismatched types.
TEST
(
NetworkStateManagerTest
,
LookupLayer
)
{
NetworkStateManager
manager
;
LayerHandle
<
char
>
char_handle
;
LayerHandle
<
int16
>
int16_handle
;
LayerHandle
<
uint16
>
uint16_handle
;
PairwiseLayerHandle
<
char
>
pairwise_char_handle
;
size_t
dimension
=
0
;
// Add some typed layers and aliases.
TF_ASSERT_OK
(
manager
.
AddComponent
(
"foo"
));
TF_ASSERT_OK
(
manager
.
AddLayer
(
"char"
,
5
,
&
char_handle
));
TF_ASSERT_OK
(
manager
.
AddLayer
(
"int16"
,
7
,
&
int16_handle
));
TF_ASSERT_OK
(
manager
.
AddLayerAlias
(
"char_alias"
,
"char"
));
TF_ASSERT_OK
(
manager
.
AddLayerAlias
(
"int16_alias"
,
"int16"
));
TF_ASSERT_OK
(
manager
.
AddComponent
(
"bar"
));
TF_ASSERT_OK
(
manager
.
AddLayer
(
"uint16"
,
11
,
&
uint16_handle
));
TF_ASSERT_OK
(
manager
.
AddLayer
(
"pairwise_char"
,
13
,
&
pairwise_char_handle
));
TF_ASSERT_OK
(
manager
.
AddLayerAlias
(
"uint16_alias"
,
"uint16"
));
TF_ASSERT_OK
(
manager
.
AddLayerAlias
(
"pairwise_char_alias"
,
"pairwise_char"
));
// Try looking up unknown components.
EXPECT_THAT
(
manager
.
LookupLayer
(
"missing"
,
"char"
,
&
dimension
,
&
char_handle
),
test
::
IsErrorWithSubstr
(
"Unknown component 'missing'"
));
EXPECT_THAT
(
manager
.
LookupLayer
(
"baz"
,
"float"
,
&
dimension
,
&
char_handle
),
test
::
IsErrorWithSubstr
(
"Unknown component 'baz'"
));
// Try looking up valid components but unknown layers.
EXPECT_THAT
(
manager
.
LookupLayer
(
"foo"
,
"missing"
,
&
dimension
,
&
char_handle
),
test
::
IsErrorWithSubstr
(
"Unknown layer 'missing' in component 'foo'"
));
EXPECT_THAT
(
manager
.
LookupLayer
(
"bar"
,
"missing"
,
&
dimension
,
&
char_handle
),
test
::
IsErrorWithSubstr
(
"Unknown layer 'missing' in component 'bar'"
));
// Try looking up valid components and the names of layers or aliases in the
// other components.
EXPECT_THAT
(
manager
.
LookupLayer
(
"foo"
,
"uint16"
,
&
dimension
,
&
uint16_handle
),
test
::
IsErrorWithSubstr
(
"Unknown layer 'uint16' in component 'foo'"
));
EXPECT_THAT
(
manager
.
LookupLayer
(
"foo"
,
"uint16_alias"
,
&
dimension
,
&
uint16_handle
),
test
::
IsErrorWithSubstr
(
"Unknown layer 'uint16_alias' in component 'foo'"
));
EXPECT_THAT
(
manager
.
LookupLayer
(
"bar"
,
"char"
,
&
dimension
,
&
char_handle
),
test
::
IsErrorWithSubstr
(
"Unknown layer 'char' in component 'bar'"
));
EXPECT_THAT
(
manager
.
LookupLayer
(
"bar"
,
"char_alias"
,
&
dimension
,
&
char_handle
),
test
::
IsErrorWithSubstr
(
"Unknown layer 'char_alias' in component 'bar'"
));
// Look up layers with incorrect types.
EXPECT_THAT
(
manager
.
LookupLayer
(
"foo"
,
"char"
,
&
dimension
,
&
int16_handle
),
test
::
IsErrorWithSubstr
(
"Layer 'char' in component 'foo' does not match its expected type"
));
EXPECT_THAT
(
manager
.
LookupLayer
(
"foo"
,
"char"
,
&
dimension
,
&
uint16_handle
),
test
::
IsErrorWithSubstr
(
"Layer 'char' in component 'foo' does not match its expected type"
));
EXPECT_THAT
(
manager
.
LookupLayer
(
"foo"
,
"char"
,
&
dimension
,
&
pairwise_char_handle
),
test
::
IsErrorWithSubstr
(
"Layer 'char' in component 'foo' does not match "
"its expected OperandType"
));
EXPECT_THAT
(
manager
.
LookupLayer
(
"foo"
,
"int16"
,
&
dimension
,
&
char_handle
),
test
::
IsErrorWithSubstr
(
"Layer 'int16' in component 'foo' does not match its expected type"
));
EXPECT_THAT
(
manager
.
LookupLayer
(
"foo"
,
"int16"
,
&
dimension
,
&
uint16_handle
),
test
::
IsErrorWithSubstr
(
"Layer 'int16' in component 'foo' does not match its expected type"
));
EXPECT_THAT
(
manager
.
LookupLayer
(
"foo"
,
"int16"
,
&
dimension
,
&
pairwise_char_handle
),
test
::
IsErrorWithSubstr
(
"Layer 'int16' in component 'foo' does not match its expected type"
));
EXPECT_THAT
(
manager
.
LookupLayer
(
"bar"
,
"uint16"
,
&
dimension
,
&
char_handle
),
test
::
IsErrorWithSubstr
(
"Layer 'uint16' in component 'bar' does "
"not match its expected type"
));
EXPECT_THAT
(
manager
.
LookupLayer
(
"bar"
,
"uint16"
,
&
dimension
,
&
int16_handle
),
test
::
IsErrorWithSubstr
(
"Layer 'uint16' in component 'bar' does "
"not match its expected type"
));
EXPECT_THAT
(
manager
.
LookupLayer
(
"bar"
,
"uint16"
,
&
dimension
,
&
pairwise_char_handle
),
test
::
IsErrorWithSubstr
(
"Layer 'uint16' in component 'bar' does "
"not match its expected type"
));
EXPECT_THAT
(
manager
.
LookupLayer
(
"bar"
,
"pairwise_char"
,
&
dimension
,
&
char_handle
),
test
::
IsErrorWithSubstr
(
"Layer 'pairwise_char' in component 'bar' does "
"not match its expected OperandType"
));
EXPECT_THAT
(
manager
.
LookupLayer
(
"bar"
,
"pairwise_char"
,
&
dimension
,
&
int16_handle
),
test
::
IsErrorWithSubstr
(
"Layer 'pairwise_char' in component 'bar' does "
"not match its expected type"
));
EXPECT_THAT
(
manager
.
LookupLayer
(
"bar"
,
"pairwise_char"
,
&
dimension
,
&
uint16_handle
),
test
::
IsErrorWithSubstr
(
"Layer 'pairwise_char' in component 'bar' does "
"not match its expected type"
));
// Look up layers properly, and check their dimensions. Also verify that the
// looked-up handles are identical to the original handles.
LayerHandle
<
char
>
lookup_char_handle
;
LayerHandle
<
int16
>
lookup_int16_handle
;
LayerHandle
<
uint16
>
lookup_uint16_handle
;
PairwiseLayerHandle
<
char
>
lookup_pairwise_char_handle
;
TF_EXPECT_OK
(
manager
.
LookupLayer
(
"foo"
,
"char"
,
&
dimension
,
&
lookup_char_handle
));
EXPECT_EQ
(
dimension
,
5
);
ExpectBitwiseEqual
(
lookup_char_handle
,
char_handle
);
TF_EXPECT_OK
(
manager
.
LookupLayer
(
"foo"
,
"int16"
,
&
dimension
,
&
lookup_int16_handle
));
EXPECT_EQ
(
dimension
,
7
);
ExpectBitwiseEqual
(
lookup_int16_handle
,
int16_handle
);
TF_EXPECT_OK
(
manager
.
LookupLayer
(
"bar"
,
"uint16"
,
&
dimension
,
&
lookup_uint16_handle
));
EXPECT_EQ
(
dimension
,
11
);
ExpectBitwiseEqual
(
lookup_uint16_handle
,
uint16_handle
);
TF_EXPECT_OK
(
manager
.
LookupLayer
(
"bar"
,
"pairwise_char"
,
&
dimension
,
&
lookup_pairwise_char_handle
));
EXPECT_EQ
(
dimension
,
13
);
ExpectBitwiseEqual
(
lookup_pairwise_char_handle
,
pairwise_char_handle
);
}
// Tests that NetworkStates cannot start components without a manager.
TEST
(
NetworkStatesTest
,
NoManager
)
{
NetworkStates
network_states
;
EXPECT_THAT
(
network_states
.
StartNextComponent
(
10
),
test
::
IsErrorWithSubstr
(
"No manager"
));
}
// Tests that NetworkStates cannot start components when the manager is empty.
TEST
(
NetworkStatesTest
,
EmptyManager
)
{
NetworkStateManager
empty_manager
;
NetworkStates
network_states
;
network_states
.
Reset
(
&
empty_manager
);
EXPECT_THAT
(
network_states
.
StartNextComponent
(
10
),
test
::
IsErrorWithSubstr
(
"No next component"
));
}
// Tests that NetworkStates can start the same number of components as were
// configured in its manager.
TEST
(
NetworkStatesTest
,
StartNextComponent
)
{
NetworkStateManager
manager
;
TF_EXPECT_OK
(
manager
.
AddComponent
(
"foo"
));
TF_EXPECT_OK
(
manager
.
AddComponent
(
"bar"
));
TF_EXPECT_OK
(
manager
.
AddComponent
(
"baz"
));
NetworkStates
network_states
;
network_states
.
Reset
(
&
manager
);
TF_EXPECT_OK
(
network_states
.
StartNextComponent
(
10
));
TF_EXPECT_OK
(
network_states
.
StartNextComponent
(
11
));
TF_EXPECT_OK
(
network_states
.
StartNextComponent
(
12
));
EXPECT_THAT
(
network_states
.
StartNextComponent
(
13
),
test
::
IsErrorWithSubstr
(
"No next component"
));
}
// Tests that NetworkStates contains layers and locals whose dimensions match
// the configuration of its manager.
TEST
(
NetworkStatesTest
,
Dimensions
)
{
NetworkStateManager
manager
;
// The "foo" component has two layers and a local vector.
LayerHandle
<
float
>
foo_hidden_handle
;
LocalVectorHandle
<
int16
>
foo_local_handle
;
PairwiseLayerHandle
<
float
>
foo_logits_handle
;
TF_ASSERT_OK
(
manager
.
AddComponent
(
"foo"
));
TF_ASSERT_OK
(
manager
.
AddLayer
(
"hidden"
,
10
,
&
foo_hidden_handle
));
TF_ASSERT_OK
(
manager
.
AddLocal
(
20
,
&
foo_local_handle
));
TF_ASSERT_OK
(
manager
.
AddLayer
(
"logits"
,
30
,
&
foo_logits_handle
));
// The "bar" component has one layer and a local matrix.
LayerHandle
<
float
>
bar_logits_handle
;
LocalMatrixHandle
<
bool
>
bar_local_handle
;
TF_ASSERT_OK
(
manager
.
AddComponent
(
"bar"
));
TF_ASSERT_OK
(
manager
.
AddLayer
(
"logits"
,
40
,
&
bar_logits_handle
));
TF_ASSERT_OK
(
manager
.
AddLocal
(
50
,
&
bar_local_handle
));
// Initialize a NetworkStates and check its dimensions. Note that matrices
// start with 0 rows since there are 0 steps.
NetworkStates
network_states
;
network_states
.
Reset
(
&
manager
);
TF_EXPECT_OK
(
network_states
.
StartNextComponent
(
13
));
ExpectDimensions
(
network_states
.
GetLayer
(
foo_hidden_handle
),
0
,
10
);
EXPECT_EQ
(
network_states
.
GetLocal
(
foo_local_handle
).
size
(),
20
);
ExpectDimensions
(
network_states
.
GetLayer
(
foo_logits_handle
),
0
,
0
);
// Add some steps, and check that rows have been added to matrices, while
// vectors are unaffected.
network_states
.
AddSteps
(
19
);
ExpectDimensions
(
network_states
.
GetLayer
(
foo_hidden_handle
),
19
,
10
);
EXPECT_EQ
(
network_states
.
GetLocal
(
foo_local_handle
).
size
(),
20
);
ExpectDimensions
(
network_states
.
GetLayer
(
foo_logits_handle
),
19
,
19
*
30
);
// Again for the next component.
TF_EXPECT_OK
(
network_states
.
StartNextComponent
(
9
));
ExpectDimensions
(
network_states
.
GetLayer
(
bar_logits_handle
),
0
,
40
);
ExpectDimensions
(
network_states
.
GetLocal
(
bar_local_handle
),
0
,
50
);
// Add some steps, and check that rows have been added to matrices.
network_states
.
AddSteps
(
25
);
ExpectDimensions
(
network_states
.
GetLayer
(
bar_logits_handle
),
25
,
40
);
ExpectDimensions
(
network_states
.
GetLocal
(
bar_local_handle
),
25
,
50
);
EXPECT_THAT
(
network_states
.
StartNextComponent
(
10
),
test
::
IsErrorWithSubstr
(
"No next component"
));
// Check the layers of the first component. They should still have the same
// dimensions in spite of adding steps to the second component.
ExpectDimensions
(
network_states
.
GetLayer
(
foo_hidden_handle
),
19
,
10
);
ExpectDimensions
(
network_states
.
GetLayer
(
foo_logits_handle
),
19
,
19
*
30
);
}
// Tests that NetworkStates can be reused by resetting them repeatedly, possibly
// switching between different managers.
TEST
(
NetworkStatesTest
,
ResetWithDifferentManagers
)
{
std
::
vector
<
NetworkStateManager
>
managers
(
10
);
std
::
vector
<
LayerHandle
<
int
>>
layer_handles
(
10
);
std
::
vector
<
PairwiseLayerHandle
<
int
>>
pairwise_layer_handles
(
10
);
std
::
vector
<
LocalVectorHandle
<
int
>>
vector_handles
(
10
);
std
::
vector
<
LocalMatrixHandle
<
double
>>
matrix_handles
(
10
);
for
(
int
dim
=
0
;
dim
<
10
;
++
dim
)
{
TF_ASSERT_OK
(
managers
[
dim
].
AddComponent
(
"foo"
));
TF_ASSERT_OK
(
managers
[
dim
].
AddLayer
(
tensorflow
::
strings
::
StrCat
(
"layer"
,
dim
),
dim
,
&
layer_handles
[
dim
]));
TF_ASSERT_OK
(
managers
[
dim
].
AddLayer
(
tensorflow
::
strings
::
StrCat
(
"pairwise"
,
dim
),
dim
,
&
pairwise_layer_handles
[
dim
]));
TF_ASSERT_OK
(
managers
[
dim
].
AddLocal
(
dim
,
&
vector_handles
[
dim
]));
TF_ASSERT_OK
(
managers
[
dim
].
AddLocal
(
dim
,
&
matrix_handles
[
dim
]));
}
NetworkStates
network_states
;
for
(
int
trial
=
0
;
trial
<
10
;
++
trial
)
{
for
(
int
dim
=
0
;
dim
<
10
;
++
dim
)
{
network_states
.
Reset
(
&
managers
[
dim
]);
TF_ASSERT_OK
(
network_states
.
StartNextComponent
(
10
));
// Fill the vector local.
Fill
(
network_states
.
GetLocal
(
vector_handles
[
dim
]),
dim
,
100
*
trial
+
dim
);
// Check the vector local.
ExpectFilled
(
network_states
.
GetLocal
(
vector_handles
[
dim
]),
dim
,
100
*
trial
+
dim
);
// Repeatedly add a step and fill it with values.
for
(
int
step
=
0
;
step
<
100
;
++
step
)
{
network_states
.
AddStep
();
Fill
(
network_states
.
GetLayer
(
layer_handles
[
dim
]).
row
(
step
),
dim
,
1000
*
trial
+
100
*
dim
+
step
);
Fill
(
network_states
.
GetLocal
(
matrix_handles
[
dim
]).
row
(
step
),
dim
,
9876.0
*
trial
+
100
*
dim
+
step
);
}
// Check that data from earlier steps is preserved across reallocations.
for
(
int
step
=
0
;
step
<
100
;
++
step
)
{
ExpectFilled
(
network_states
.
GetLayer
(
layer_handles
[
dim
]).
row
(
step
),
dim
,
1000
*
trial
+
100
*
dim
+
step
);
ExpectFilled
(
network_states
.
GetLocal
(
matrix_handles
[
dim
]).
row
(
step
),
dim
,
9876.0
*
trial
+
100
*
dim
+
step
);
}
ExpectDimensions
(
network_states
.
GetLayer
(
pairwise_layer_handles
[
dim
]),
100
,
100
*
dim
);
}
}
}
// Tests that one NetworkStateManager can be shared simultaneously between
// multiple NetworkStates instances.
TEST
(
NetworkStatesTest
,
SharedManager
)
{
const
size_t
kDim
=
17
;
NetworkStateManager
manager
;
LayerHandle
<
int
>
layer_handle
;
PairwiseLayerHandle
<
int
>
pairwise_layer_handle
;
LocalVectorHandle
<
int
>
vector_handle
;
LocalMatrixHandle
<
double
>
matrix_handle
;
TF_ASSERT_OK
(
manager
.
AddComponent
(
"foo"
));
TF_ASSERT_OK
(
manager
.
AddLayer
(
"layer"
,
kDim
,
&
layer_handle
));
TF_ASSERT_OK
(
manager
.
AddLayer
(
"pairwise"
,
kDim
,
&
pairwise_layer_handle
));
TF_ASSERT_OK
(
manager
.
AddLocal
(
kDim
,
&
vector_handle
));
TF_ASSERT_OK
(
manager
.
AddLocal
(
kDim
,
&
matrix_handle
));
std
::
vector
<
NetworkStates
>
network_states_vec
(
10
);
for
(
NetworkStates
&
network_states
:
network_states_vec
)
{
network_states
.
Reset
(
&
manager
);
TF_ASSERT_OK
(
network_states
.
StartNextComponent
(
10
));
}
// Fill all vectors.
for
(
int
trial
=
0
;
trial
<
network_states_vec
.
size
();
++
trial
)
{
const
NetworkStates
&
network_states
=
network_states_vec
[
trial
];
Fill
(
network_states
.
GetLocal
(
vector_handle
),
kDim
,
3
*
trial
);
}
// Check all vectors.
for
(
int
trial
=
0
;
trial
<
network_states_vec
.
size
();
++
trial
)
{
const
NetworkStates
&
network_states
=
network_states_vec
[
trial
];
ExpectFilled
(
network_states
.
GetLocal
(
vector_handle
),
kDim
,
3
*
trial
);
}
// Fill all matrices. Interleave operations on the network states on each
// step, so all network states are "active" at the same time.
for
(
int
step
=
0
;
step
<
100
;
++
step
)
{
for
(
int
trial
=
0
;
trial
<
10
;
++
trial
)
{
NetworkStates
&
network_states
=
network_states_vec
[
trial
];
network_states
.
AddStep
();
Fill
(
network_states
.
GetLayer
(
layer_handle
).
row
(
step
),
kDim
,
999
*
trial
+
step
);
Fill
(
network_states
.
GetLocal
(
matrix_handle
).
row
(
step
),
kDim
,
1234.0
*
trial
+
step
);
ExpectDimensions
(
network_states
.
GetLayer
(
pairwise_layer_handle
),
step
+
1
,
kDim
*
(
step
+
1
));
}
}
// Check all matrices.
for
(
int
step
=
0
;
step
<
100
;
++
step
)
{
for
(
int
trial
=
0
;
trial
<
10
;
++
trial
)
{
const
NetworkStates
&
network_states
=
network_states_vec
[
trial
];
ExpectFilled
(
network_states
.
GetLayer
(
layer_handle
).
row
(
step
),
kDim
,
999
*
trial
+
step
);
ExpectFilled
(
network_states
.
GetLocal
(
matrix_handle
).
row
(
step
),
kDim
,
1234.0
*
trial
+
step
);
ExpectDimensions
(
network_states
.
GetLayer
(
pairwise_layer_handle
),
100
,
kDim
*
100
);
}
}
}
}
// namespace
}
// namespace runtime
}
// namespace dragnn
}
// namespace syntaxnet
Prev
1
…
8
9
10
11
12
13
14
15
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment