Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
ModelZoo
ResNet50_tensorflow
Commits
277f99c7
Commit
277f99c7
authored
Mar 23, 2017
by
Ivan Bogatyy
Committed by
GitHub
Mar 23, 2017
Browse files
Merge pull request #1243 from bogatyy/master
Add license headers, fix some macOS issues
parents
f7cea8d0
ea3fa4a3
Changes
115
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
672 additions
and
48 deletions
+672
-48
syntaxnet/dragnn/components/stateless/BUILD
syntaxnet/dragnn/components/stateless/BUILD
+34
-0
syntaxnet/dragnn/components/stateless/stateless_component.cc
syntaxnet/dragnn/components/stateless/stateless_component.cc
+131
-0
syntaxnet/dragnn/components/stateless/stateless_component_test.cc
...t/dragnn/components/stateless/stateless_component_test.cc
+171
-0
syntaxnet/dragnn/components/syntaxnet/BUILD
syntaxnet/dragnn/components/syntaxnet/BUILD
+8
-9
syntaxnet/dragnn/components/syntaxnet/syntaxnet_component.cc
syntaxnet/dragnn/components/syntaxnet/syntaxnet_component.cc
+15
-0
syntaxnet/dragnn/components/syntaxnet/syntaxnet_component.h
syntaxnet/dragnn/components/syntaxnet/syntaxnet_component.h
+15
-0
syntaxnet/dragnn/components/syntaxnet/syntaxnet_component_test.cc
...t/dragnn/components/syntaxnet/syntaxnet_component_test.cc
+104
-5
syntaxnet/dragnn/components/syntaxnet/syntaxnet_link_feature_extractor.cc
.../components/syntaxnet/syntaxnet_link_feature_extractor.cc
+15
-0
syntaxnet/dragnn/components/syntaxnet/syntaxnet_link_feature_extractor.h
...n/components/syntaxnet/syntaxnet_link_feature_extractor.h
+15
-0
syntaxnet/dragnn/components/syntaxnet/syntaxnet_link_feature_extractor_test.cc
...onents/syntaxnet/syntaxnet_link_feature_extractor_test.cc
+15
-0
syntaxnet/dragnn/components/syntaxnet/syntaxnet_transition_state.cc
...dragnn/components/syntaxnet/syntaxnet_transition_state.cc
+15
-0
syntaxnet/dragnn/components/syntaxnet/syntaxnet_transition_state.h
.../dragnn/components/syntaxnet/syntaxnet_transition_state.h
+15
-0
syntaxnet/dragnn/components/syntaxnet/syntaxnet_transition_state_test.cc
...n/components/syntaxnet/syntaxnet_transition_state_test.cc
+15
-0
syntaxnet/dragnn/components/util/BUILD
syntaxnet/dragnn/components/util/BUILD
+5
-2
syntaxnet/dragnn/components/util/bulk_feature_extractor.h
syntaxnet/dragnn/components/util/bulk_feature_extractor.h
+15
-0
syntaxnet/dragnn/core/BUILD
syntaxnet/dragnn/core/BUILD
+23
-31
syntaxnet/dragnn/core/beam.h
syntaxnet/dragnn/core/beam.h
+16
-1
syntaxnet/dragnn/core/beam_test.cc
syntaxnet/dragnn/core/beam_test.cc
+15
-0
syntaxnet/dragnn/core/component_registry.cc
syntaxnet/dragnn/core/component_registry.cc
+15
-0
syntaxnet/dragnn/core/component_registry.h
syntaxnet/dragnn/core/component_registry.h
+15
-0
No files found.
syntaxnet/dragnn/components/stateless/BUILD
0 → 100644
View file @
277f99c7
package
(
default_visibility
=
[
"//visibility:public"
],
features
=
[
"-layering_check"
],
)
cc_library
(
name
=
"stateless_component"
,
srcs
=
[
"stateless_component.cc"
],
deps
=
[
"//dragnn/core:component_registry"
,
"//dragnn/core/interfaces:component"
,
"//dragnn/core/interfaces:transition_state"
,
"//dragnn/io:sentence_input_batch"
,
"//dragnn/protos:data_proto"
,
"//syntaxnet:base"
,
],
alwayslink
=
1
,
)
cc_test
(
name
=
"stateless_component_test"
,
srcs
=
[
"stateless_component_test.cc"
],
deps
=
[
":stateless_component"
,
"//dragnn/core:component_registry"
,
"//dragnn/core:input_batch_cache"
,
"//dragnn/core/test:generic"
,
"//dragnn/core/test:mock_transition_state"
,
"//dragnn/io:sentence_input_batch"
,
"//syntaxnet:base"
,
"//syntaxnet:sentence_proto"
,
"//syntaxnet:test_main"
,
],
)
syntaxnet/dragnn/components/stateless/stateless_component.cc
0 → 100644
View file @
277f99c7
// Copyright 2017 Google Inc. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
#include "dragnn/core/component_registry.h"
#include "dragnn/core/interfaces/component.h"
#include "dragnn/core/interfaces/transition_state.h"
#include "dragnn/io/sentence_input_batch.h"
#include "dragnn/protos/data.pb.h"
#include "syntaxnet/base.h"
namespace
syntaxnet
{
namespace
dragnn
{
namespace
{
// A component that does not create its own transition states; instead, it
// simply forwards the states of the previous component. Does not support all
// methods. Intended for "compute-only" bulk components that only use linked
// features, which use only a small subset of DRAGNN functionality.
class
StatelessComponent
:
public
Component
{
public:
void
InitializeComponent
(
const
ComponentSpec
&
spec
)
override
{
name_
=
spec
.
name
();
}
// Stores the |parent_states| for forwarding to downstream components.
void
InitializeData
(
const
std
::
vector
<
std
::
vector
<
const
TransitionState
*>>
&
parent_states
,
int
max_beam_size
,
InputBatchCache
*
input_data
)
override
{
// Must use SentenceInputBatch to match SyntaxNetComponent.
batch_size_
=
input_data
->
GetAs
<
SentenceInputBatch
>
()
->
data
()
->
size
();
beam_size_
=
max_beam_size
;
parent_states_
=
parent_states
;
// The beam should be wide enough for the previous component.
for
(
const
auto
&
beam
:
parent_states
)
{
CHECK_LE
(
beam
.
size
(),
beam_size_
);
}
}
// Forwards the states of the previous component.
std
::
vector
<
std
::
vector
<
const
TransitionState
*>>
GetBeam
()
override
{
return
parent_states_
;
}
// Forwards the |current_index| to the previous component.
int
GetSourceBeamIndex
(
int
current_index
,
int
batch
)
const
override
{
return
current_index
;
}
string
Name
()
const
override
{
return
name_
;
}
int
BeamSize
()
const
override
{
return
beam_size_
;
}
int
BatchSize
()
const
override
{
return
batch_size_
;
}
int
StepsTaken
(
int
batch_index
)
const
override
{
return
0
;
}
bool
IsReady
()
const
override
{
return
true
;
}
bool
IsTerminal
()
const
override
{
return
true
;
}
void
FinalizeData
()
override
{}
void
ResetComponent
()
override
{}
void
InitializeTracing
()
override
{}
void
DisableTracing
()
override
{}
std
::
vector
<
std
::
vector
<
ComponentTrace
>>
GetTraceProtos
()
const
override
{
return
{};
}
// Unsupported methods.
int
GetBeamIndexAtStep
(
int
step
,
int
current_index
,
int
batch
)
const
override
{
LOG
(
FATAL
)
<<
"["
<<
name_
<<
"] Method not supported"
;
return
0
;
}
std
::
function
<
int
(
int
,
int
,
int
)
>
GetStepLookupFunction
(
const
string
&
method
)
override
{
LOG
(
FATAL
)
<<
"["
<<
name_
<<
"] Method not supported"
;
return
nullptr
;
}
void
AdvanceFromPrediction
(
const
float
transition_matrix
[],
int
matrix_length
)
override
{
LOG
(
FATAL
)
<<
"["
<<
name_
<<
"] Method not supported"
;
}
void
AdvanceFromOracle
()
override
{
LOG
(
FATAL
)
<<
"["
<<
name_
<<
"] Method not supported"
;
}
std
::
vector
<
std
::
vector
<
int
>>
GetOracleLabels
()
const
override
{
LOG
(
FATAL
)
<<
"["
<<
name_
<<
"] Method not supported"
;
return
{};
}
int
GetFixedFeatures
(
std
::
function
<
int32
*
(
int
)
>
allocate_indices
,
std
::
function
<
int64
*
(
int
)
>
allocate_ids
,
std
::
function
<
float
*
(
int
)
>
allocate_weights
,
int
channel_id
)
const
override
{
LOG
(
FATAL
)
<<
"["
<<
name_
<<
"] Method not supported"
;
return
0
;
}
int
BulkGetFixedFeatures
(
const
BulkFeatureExtractor
&
extractor
)
override
{
LOG
(
FATAL
)
<<
"["
<<
name_
<<
"] Method not supported"
;
return
0
;
}
std
::
vector
<
LinkFeatures
>
GetRawLinkFeatures
(
int
channel_id
)
const
override
{
LOG
(
FATAL
)
<<
"["
<<
name_
<<
"] Method not supported"
;
return
{};
}
void
AddTranslatedLinkFeaturesToTrace
(
const
std
::
vector
<
LinkFeatures
>
&
features
,
int
channel_id
)
override
{
LOG
(
FATAL
)
<<
"["
<<
name_
<<
"] Method not supported"
;
}
private:
string
name_
;
// component name
int
batch_size_
=
1
;
// number of sentences in current batch
int
beam_size_
=
1
;
// maximum beam size
// Parent states passed to InitializeData(), and passed along in GetBeam().
std
::
vector
<
std
::
vector
<
const
TransitionState
*>>
parent_states_
;
};
REGISTER_DRAGNN_COMPONENT
(
StatelessComponent
);
}
// namespace
}
// namespace dragnn
}
// namespace syntaxnet
syntaxnet/dragnn/components/stateless/stateless_component_test.cc
0 → 100644
View file @
277f99c7
// Copyright 2017 Google Inc. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
#include "dragnn/core/component_registry.h"
#include "dragnn/core/input_batch_cache.h"
#include "dragnn/core/test/generic.h"
#include "dragnn/core/test/mock_transition_state.h"
#include "dragnn/io/sentence_input_batch.h"
#include "syntaxnet/base.h"
#include "syntaxnet/sentence.pb.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/io/path.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/protobuf.h"
#include "tensorflow/core/platform/test.h"
namespace
syntaxnet
{
namespace
dragnn
{
namespace
{
const
char
kSentence0
[]
=
R"(
token {
word: "Sentence" start: 0 end: 7 tag: "NN" category: "NOUN" label: "ROOT"
break_level: NO_BREAK
}
token {
word: "0" start: 9 end: 9 head: 0 tag: "CD" category: "NUM" label: "num"
break_level: SPACE_BREAK
}
token {
word: "." start: 10 end: 10 head: 0 tag: "." category: "." label: "punct"
break_level: NO_BREAK
}
)"
;
const
char
kSentence1
[]
=
R"(
token {
word: "Sentence" start: 0 end: 7 tag: "NN" category: "NOUN" label: "ROOT"
break_level: NO_BREAK
}
token {
word: "1" start: 9 end: 9 head: 0 tag: "CD" category: "NUM" label: "num"
break_level: SPACE_BREAK
}
token {
word: "." start: 10 end: 10 head: 0 tag: "." category: "." label: "punct"
break_level: NO_BREAK
}
)"
;
const
char
kLongSentence
[]
=
R"(
token {
word: "Sentence" start: 0 end: 7 tag: "NN" category: "NOUN" label: "ROOT"
break_level: NO_BREAK
}
token {
word: "1" start: 9 end: 9 head: 0 tag: "CD" category: "NUM" label: "num"
break_level: SPACE_BREAK
}
token {
word: "2" start: 10 end: 10 head: 0 tag: "CD" category: "NUM" label: "num"
break_level: SPACE_BREAK
}
token {
word: "3" start: 11 end: 11 head: 0 tag: "CD" category: "NUM" label: "num"
break_level: SPACE_BREAK
}
token {
word: "." start: 12 end: 12 head: 0 tag: "." category: "." label: "punct"
break_level: NO_BREAK
}
)"
;
const
char
kMasterSpec
[]
=
R"(
component {
name: "test"
transition_system {
registered_name: "shift-only"
}
linked_feature {
name: "prev"
fml: "input.focus"
embedding_dim: 32
size: 1
source_component: "prev"
source_translator: "identity"
source_layer: "last_layer"
}
backend {
registered_name: "StatelessComponent"
}
}
)"
;
}
// namespace
using
testing
::
Return
;
class
StatelessComponentTest
:
public
::
testing
::
Test
{
public:
std
::
unique_ptr
<
Component
>
CreateParser
(
int
beam_size
,
const
std
::
vector
<
std
::
vector
<
const
TransitionState
*>>
&
states
,
const
std
::
vector
<
string
>
&
data
)
{
MasterSpec
master_spec
;
CHECK
(
TextFormat
::
ParseFromString
(
kMasterSpec
,
&
master_spec
));
data_
.
reset
(
new
InputBatchCache
(
data
));
// Create a parser component with the specified beam size.
std
::
unique_ptr
<
Component
>
parser_component
(
Component
::
Create
(
"StatelessComponent"
));
parser_component
->
InitializeComponent
(
master_spec
.
component
(
0
));
parser_component
->
InitializeData
(
states
,
beam_size
,
data_
.
get
());
return
parser_component
;
}
std
::
unique_ptr
<
InputBatchCache
>
data_
;
};
TEST_F
(
StatelessComponentTest
,
ForwardsTransitionStates
)
{
const
MockTransitionState
mock_state_1
,
mock_state_2
,
mock_state_3
;
const
std
::
vector
<
std
::
vector
<
const
TransitionState
*>>
parent_states
=
{
{},
{
&
mock_state_1
},
{
&
mock_state_2
,
&
mock_state_3
}};
std
::
vector
<
string
>
data
;
for
(
const
string
&
textproto
:
{
kSentence0
,
kSentence1
,
kLongSentence
})
{
Sentence
sentence
;
CHECK
(
TextFormat
::
ParseFromString
(
textproto
,
&
sentence
));
data
.
emplace_back
();
CHECK
(
sentence
.
SerializeToString
(
&
data
.
back
()));
}
CHECK_EQ
(
parent_states
.
size
(),
data
.
size
());
const
int
kBeamSize
=
2
;
auto
test_parser
=
CreateParser
(
kBeamSize
,
parent_states
,
data
);
EXPECT_TRUE
(
test_parser
->
IsReady
());
EXPECT_TRUE
(
test_parser
->
IsTerminal
());
EXPECT_EQ
(
kBeamSize
,
test_parser
->
BeamSize
());
EXPECT_EQ
(
data
.
size
(),
test_parser
->
BatchSize
());
EXPECT_TRUE
(
test_parser
->
GetTraceProtos
().
empty
());
for
(
int
batch_index
=
0
;
batch_index
<
parent_states
.
size
();
++
batch_index
)
{
EXPECT_EQ
(
0
,
test_parser
->
StepsTaken
(
batch_index
));
const
auto
&
beam
=
parent_states
[
batch_index
];
for
(
int
beam_index
=
0
;
beam_index
<
beam
.
size
();
++
beam_index
)
{
// Expect an identity mapping.
EXPECT_EQ
(
beam_index
,
test_parser
->
GetSourceBeamIndex
(
beam_index
,
batch_index
));
}
}
const
auto
forwarded_states
=
test_parser
->
GetBeam
();
EXPECT_EQ
(
parent_states
,
forwarded_states
);
}
}
// namespace dragnn
}
// namespace syntaxnet
syntaxnet/dragnn/components/syntaxnet/BUILD
View file @
277f99c7
package
(
default_visibility
=
[
"//visibility:public"
])
package
(
default_visibility
=
[
"//visibility:public"
],
features
=
[
"-layering_check"
],
)
cc_library
(
name
=
"syntaxnet_component"
,
...
...
@@ -25,7 +28,6 @@ cc_library(
"//syntaxnet:task_context"
,
"//syntaxnet:task_spec_proto"
,
"//syntaxnet:utils"
,
"@org_tensorflow//tensorflow/core:lib"
,
# For tf/core/platform/logging.h
],
alwayslink
=
1
,
)
...
...
@@ -36,10 +38,10 @@ cc_library(
hdrs
=
[
"syntaxnet_link_feature_extractor.h"
],
deps
=
[
"//dragnn/protos:spec_proto"
,
"//syntaxnet:base"
,
"//syntaxnet:embedding_feature_extractor"
,
"//syntaxnet:parser_transitions"
,
"//syntaxnet:task_context"
,
"@org_tensorflow//tensorflow/core:lib"
,
# For tf/core/platform/logging.h
],
)
...
...
@@ -54,7 +56,6 @@ cc_library(
"//dragnn/protos:trace_proto"
,
"//syntaxnet:base"
,
"//syntaxnet:parser_transitions"
,
"@org_tensorflow//tensorflow/core:lib"
,
# For tf/core/platform/logging.h
],
)
...
...
@@ -75,9 +76,9 @@ cc_test(
"//dragnn/core/test:generic"
,
"//dragnn/core/test:mock_transition_state"
,
"//dragnn/io:sentence_input_batch"
,
"//syntaxnet:base"
,
"//syntaxnet:sentence_proto"
,
"@org_tensorflow//tensorflow/core:lib"
,
"@org_tensorflow//tensorflow/core:test"
,
"//syntaxnet:test_main"
,
],
)
...
...
@@ -90,7 +91,6 @@ cc_test(
"//dragnn/protos:spec_proto"
,
"//syntaxnet:task_context"
,
"//syntaxnet:test_main"
,
"@org_tensorflow//tensorflow/core:test"
,
"@org_tensorflow//tensorflow/core:testlib"
,
],
)
...
...
@@ -107,10 +107,9 @@ cc_test(
"//dragnn/core/test:mock_transition_state"
,
"//dragnn/io:sentence_input_batch"
,
"//dragnn/protos:spec_proto"
,
"//syntaxnet:base"
,
"//syntaxnet:sentence_proto"
,
"//syntaxnet:test_main"
,
"@org_tensorflow//tensorflow/core:lib"
,
"@org_tensorflow//tensorflow/core:test"
,
"@org_tensorflow//tensorflow/core:testlib"
,
],
)
syntaxnet/dragnn/components/syntaxnet/syntaxnet_component.cc
View file @
277f99c7
// Copyright 2017 Google Inc. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
#include "dragnn/components/syntaxnet/syntaxnet_component.h"
#include <vector>
...
...
syntaxnet/dragnn/components/syntaxnet/syntaxnet_component.h
View file @
277f99c7
// Copyright 2017 Google Inc. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
#ifndef NLP_SAFT_OPENSOURCE_DRAGNN_COMPONENTS_SYNTAXNET_SYNTAXNET_COMPONENT_H_
#define NLP_SAFT_OPENSOURCE_DRAGNN_COMPONENTS_SYNTAXNET_SYNTAXNET_COMPONENT_H_
...
...
syntaxnet/dragnn/components/syntaxnet/syntaxnet_component_test.cc
View file @
277f99c7
// Copyright 2017 Google Inc. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
#include "dragnn/components/syntaxnet/syntaxnet_component.h"
#include "dragnn/core/input_batch_cache.h"
...
...
@@ -807,6 +822,90 @@ TEST_F(SyntaxNetComponentTest, ExportsFixedFeatures) {
transition_matrix
[
i
]
=
kTransitionValue
;
}
// Advance twice, so that the underlying parser fills the beam.
test_parser
->
AdvanceFromPrediction
(
transition_matrix
,
kNumPossibleTransitions
*
kBeamSize
*
kBatchSize
);
test_parser
->
AdvanceFromPrediction
(
transition_matrix
,
kNumPossibleTransitions
*
kBeamSize
*
kBatchSize
);
// Get and check the raw link features.
vector
<
int32
>
indices
;
auto
indices_fn
=
[
&
indices
](
int
size
)
{
indices
.
resize
(
size
);
return
indices
.
data
();
};
vector
<
int64
>
ids
;
auto
ids_fn
=
[
&
ids
](
int
size
)
{
ids
.
resize
(
size
);
return
ids
.
data
();
};
vector
<
float
>
weights
;
auto
weights_fn
=
[
&
weights
](
int
size
)
{
weights
.
resize
(
size
);
return
weights
.
data
();
};
constexpr
int
kChannelId
=
0
;
const
int
num_features
=
test_parser
->
GetFixedFeatures
(
indices_fn
,
ids_fn
,
weights_fn
,
kChannelId
);
constexpr
int
kExpectedOutputSize
=
12
;
const
vector
<
int32
>
expected_indices
({
0
,
1
,
2
,
3
,
4
,
5
,
6
,
7
,
8
,
9
,
10
,
11
});
const
vector
<
int64
>
expected_ids
({
7
,
50
,
12
,
7
,
12
,
7
,
7
,
50
,
12
,
7
,
12
,
7
});
const
vector
<
float
>
expected_weights
(
{
1.0
,
1.0
,
1.0
,
1.0
,
1.0
,
1.0
,
1.0
,
1.0
,
1.0
,
1.0
,
1.0
,
1.0
});
EXPECT_EQ
(
expected_indices
.
size
(),
kExpectedOutputSize
);
EXPECT_EQ
(
expected_ids
.
size
(),
kExpectedOutputSize
);
EXPECT_EQ
(
expected_weights
.
size
(),
kExpectedOutputSize
);
EXPECT_EQ
(
num_features
,
kExpectedOutputSize
);
EXPECT_EQ
(
expected_indices
,
indices
);
EXPECT_EQ
(
expected_ids
,
ids
);
EXPECT_EQ
(
expected_weights
,
weights
);
}
TEST_F
(
SyntaxNetComponentTest
,
AdvancesAccordingToHighestWeightedInputOption
)
{
// Create an empty input batch and beam vector to initialize the parser.
Sentence
sentence_0
;
TextFormat
::
ParseFromString
(
kSentence0
,
&
sentence_0
);
string
sentence_0_str
;
sentence_0
.
SerializeToString
(
&
sentence_0_str
);
Sentence
sentence_1
;
TextFormat
::
ParseFromString
(
kSentence1
,
&
sentence_1
);
string
sentence_1_str
;
sentence_1
.
SerializeToString
(
&
sentence_1_str
);
constexpr
int
kBeamSize
=
3
;
auto
test_parser
=
CreateParserWithBeamSize
(
kBeamSize
,
{},
{
sentence_0_str
,
sentence_1_str
});
// There are 93 possible transitions for any given state. Create a transition
// array with a score of 10.0 for each transition.
constexpr
int
kBatchSize
=
2
;
constexpr
int
kNumPossibleTransitions
=
93
;
constexpr
float
kTransitionValue
=
10.0
;
float
transition_matrix
[
kNumPossibleTransitions
*
kBeamSize
*
kBatchSize
];
for
(
int
i
=
0
;
i
<
kNumPossibleTransitions
*
kBeamSize
*
kBatchSize
;
++
i
)
{
transition_matrix
[
i
]
=
kTransitionValue
;
}
// Replace the first several options with varying scores to test sorting.
constexpr
int
kBatchOffset
=
kNumPossibleTransitions
*
kBeamSize
;
transition_matrix
[
0
]
=
3
*
kTransitionValue
;
transition_matrix
[
1
]
=
3
*
kTransitionValue
;
transition_matrix
[
2
]
=
4
*
kTransitionValue
;
transition_matrix
[
3
]
=
4
*
kTransitionValue
;
transition_matrix
[
4
]
=
2
*
kTransitionValue
;
transition_matrix
[
5
]
=
2
*
kTransitionValue
;
transition_matrix
[
kBatchOffset
+
0
]
=
3
*
kTransitionValue
;
transition_matrix
[
kBatchOffset
+
1
]
=
3
*
kTransitionValue
;
transition_matrix
[
kBatchOffset
+
2
]
=
4
*
kTransitionValue
;
transition_matrix
[
kBatchOffset
+
3
]
=
4
*
kTransitionValue
;
transition_matrix
[
kBatchOffset
+
4
]
=
2
*
kTransitionValue
;
transition_matrix
[
kBatchOffset
+
5
]
=
2
*
kTransitionValue
;
// Advance twice, so that the underlying parser fills the beam.
test_parser
->
AdvanceFromPrediction
(
transition_matrix
,
kNumPossibleTransitions
*
kBeamSize
*
kBatchSize
);
...
...
@@ -836,7 +935,7 @@ TEST_F(SyntaxNetComponentTest, ExportsFixedFeatures) {
// In this case, all even features and all odd features are identical.
constexpr
int
kExpectedOutputSize
=
12
;
const
vector
<
int32
>
expected_indices
({
0
,
1
,
2
,
3
,
4
,
5
,
6
,
7
,
8
,
9
,
10
,
11
});
const
vector
<
int64
>
expected_ids
({
12
,
7
,
12
,
7
,
12
,
7
,
12
,
7
,
12
,
7
,
12
,
7
});
const
vector
<
int64
>
expected_ids
({
12
,
7
,
7
,
50
,
12
,
7
,
12
,
7
,
7
,
50
,
12
,
7
});
const
vector
<
float
>
expected_weights
(
{
1.0
,
1.0
,
1.0
,
1.0
,
1.0
,
1.0
,
1.0
,
1.0
,
1.0
,
1.0
,
1.0
,
1.0
});
...
...
@@ -1024,11 +1123,11 @@ TEST_F(SyntaxNetComponentTest, ExportsRawLinkFeatures) {
EXPECT_EQ
(
link_features
.
size
(),
kBeamSize
*
kBatchSize
*
kNumLinkFeatures
);
// These should index into batch 0.
EXPECT_EQ
(
link_features
.
at
(
0
).
feature_value
(),
-
1
);
EXPECT_EQ
(
link_features
.
at
(
0
).
feature_value
(),
1
);
EXPECT_EQ
(
link_features
.
at
(
0
).
batch_idx
(),
0
);
EXPECT_EQ
(
link_features
.
at
(
0
).
beam_idx
(),
0
);
EXPECT_EQ
(
link_features
.
at
(
1
).
feature_value
(),
-
2
);
EXPECT_EQ
(
link_features
.
at
(
1
).
feature_value
(),
0
);
EXPECT_EQ
(
link_features
.
at
(
1
).
batch_idx
(),
0
);
EXPECT_EQ
(
link_features
.
at
(
1
).
beam_idx
(),
0
);
...
...
@@ -1049,11 +1148,11 @@ TEST_F(SyntaxNetComponentTest, ExportsRawLinkFeatures) {
EXPECT_EQ
(
link_features
.
at
(
5
).
beam_idx
(),
2
);
// These should index into batch 1.
EXPECT_EQ
(
link_features
.
at
(
6
).
feature_value
(),
-
1
);
EXPECT_EQ
(
link_features
.
at
(
6
).
feature_value
(),
1
);
EXPECT_EQ
(
link_features
.
at
(
6
).
batch_idx
(),
1
);
EXPECT_EQ
(
link_features
.
at
(
6
).
beam_idx
(),
0
);
EXPECT_EQ
(
link_features
.
at
(
7
).
feature_value
(),
-
2
);
EXPECT_EQ
(
link_features
.
at
(
7
).
feature_value
(),
0
);
EXPECT_EQ
(
link_features
.
at
(
7
).
batch_idx
(),
1
);
EXPECT_EQ
(
link_features
.
at
(
7
).
beam_idx
(),
0
);
...
...
syntaxnet/dragnn/components/syntaxnet/syntaxnet_link_feature_extractor.cc
View file @
277f99c7
// Copyright 2017 Google Inc. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
#include "dragnn/components/syntaxnet/syntaxnet_link_feature_extractor.h"
#include "tensorflow/core/platform/logging.h"
...
...
syntaxnet/dragnn/components/syntaxnet/syntaxnet_link_feature_extractor.h
View file @
277f99c7
// Copyright 2017 Google Inc. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
#ifndef NLP_SAFT_OPENSOURCE_DRAGNN_COMPONENTS_SYNTAXNET_SYNTAXNET_LINK_FEATURE_EXTRACTOR_H_
#define NLP_SAFT_OPENSOURCE_DRAGNN_COMPONENTS_SYNTAXNET_SYNTAXNET_LINK_FEATURE_EXTRACTOR_H_
...
...
syntaxnet/dragnn/components/syntaxnet/syntaxnet_link_feature_extractor_test.cc
View file @
277f99c7
// Copyright 2017 Google Inc. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
#include "dragnn/components/syntaxnet/syntaxnet_link_feature_extractor.h"
#include <string>
...
...
syntaxnet/dragnn/components/syntaxnet/syntaxnet_transition_state.cc
View file @
277f99c7
// Copyright 2017 Google Inc. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
#include "dragnn/components/syntaxnet/syntaxnet_transition_state.h"
#include "tensorflow/core/lib/strings/strcat.h"
...
...
syntaxnet/dragnn/components/syntaxnet/syntaxnet_transition_state.h
View file @
277f99c7
// Copyright 2017 Google Inc. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
#ifndef NLP_SAFT_OPENSOURCE_DRAGNN_COMPONENTS_SYNTAXNET_SYNTAXNET_TRANSITION_STATE_H_
#define NLP_SAFT_OPENSOURCE_DRAGNN_COMPONENTS_SYNTAXNET_SYNTAXNET_TRANSITION_STATE_H_
...
...
syntaxnet/dragnn/components/syntaxnet/syntaxnet_transition_state_test.cc
View file @
277f99c7
// Copyright 2017 Google Inc. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
#include "dragnn/components/syntaxnet/syntaxnet_transition_state.h"
#include "dragnn/components/syntaxnet/syntaxnet_component.h"
...
...
syntaxnet/dragnn/components/util/BUILD
View file @
277f99c7
package
(
default_visibility
=
[
"//visibility:public"
])
package
(
default_visibility
=
[
"//visibility:public"
],
features
=
[
"-layering_check"
],
)
cc_library
(
name
=
"bulk_feature_extractor"
,
hdrs
=
[
"bulk_feature_extractor.h"
],
deps
=
[
"
@org_tensorflow//tensorflow/core:lib
"
,
"
//syntaxnet:base
"
,
],
)
syntaxnet/dragnn/components/util/bulk_feature_extractor.h
View file @
277f99c7
// Copyright 2017 Google Inc. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
#ifndef NLP_SAFT_OPENSOURCE_DRAGNN_COMPONENTS_UTIL_BULK_FEATURE_EXTRACTOR_H_
#define NLP_SAFT_OPENSOURCE_DRAGNN_COMPONENTS_UTIL_BULK_FEATURE_EXTRACTOR_H_
...
...
syntaxnet/dragnn/core/BUILD
View file @
277f99c7
package
(
default_visibility
=
[
"//visibility:public"
])
package
(
default_visibility
=
[
"//visibility:public"
],
features
=
[
"-layering_check"
],
)
# Test data.
filegroup
(
...
...
@@ -12,7 +15,7 @@ cc_library(
deps
=
[
"//dragnn/core/interfaces:cloneable_transition_state"
,
"//dragnn/core/interfaces:transition_state"
,
"
@org_tensorflow//tensorflow/core:lib"
,
# For tf/core/platform/logging.h
"
//syntaxnet:base"
,
],
)
...
...
@@ -50,8 +53,8 @@ cc_library(
"//dragnn/protos:data_proto"
,
"//dragnn/protos:spec_proto"
,
"//dragnn/protos:trace_proto"
,
"//syntaxnet:base"
,
"//syntaxnet:registry"
,
"@org_tensorflow//tensorflow/core:lib"
,
# For tf/core/platform/logging.h
],
)
...
...
@@ -64,7 +67,7 @@ cc_library(
":compute_session"
,
":compute_session_impl"
,
"//dragnn/protos:spec_proto"
,
"
@org_tensorflow//tensorflow/core:lib
"
,
"
//syntaxnet:base
"
,
],
)
...
...
@@ -75,7 +78,7 @@ cc_library(
deps
=
[
"//dragnn/core/interfaces:component"
,
"//dragnn/core/interfaces:transition_state"
,
"
@org_tensorflow//tensorflow/core:lib"
,
# For tf/core/platform/logging.h
"
//syntaxnet:base"
,
],
)
...
...
@@ -84,17 +87,14 @@ cc_library(
hdrs
=
[
"input_batch_cache.h"
],
deps
=
[
"//dragnn/core/interfaces:input_batch"
,
"
@org_tensorflow//tensorflow/core:lib"
,
# For tf/core/platform/logging.h
"
//syntaxnet:base"
,
],
)
cc_library
(
name
=
"resource_container"
,
hdrs
=
[
"resource_container.h"
],
deps
=
[
"//syntaxnet:base"
,
"@org_tensorflow//tensorflow/core:framework"
,
],
deps
=
[
"//syntaxnet:base"
],
)
# Tests
...
...
@@ -107,8 +107,8 @@ cc_test(
"//dragnn/core/interfaces:cloneable_transition_state"
,
"//dragnn/core/interfaces:transition_state"
,
"//dragnn/core/test:mock_transition_state"
,
"//syntaxnet:base"
,
"//syntaxnet:test_main"
,
"@org_tensorflow//tensorflow/core:test"
,
],
)
...
...
@@ -125,7 +125,7 @@ cc_test(
"//dragnn/core/test:generic"
,
"//dragnn/core/test:mock_component"
,
"//dragnn/core/test:mock_transition_state"
,
"
@org_tensorflow//tensorflow/core:test
"
,
"
//syntaxnet:base
"
,
],
)
...
...
@@ -138,9 +138,8 @@ cc_test(
"//dragnn/core/test:generic"
,
"//dragnn/core/test:mock_component"
,
"//dragnn/core/test:mock_compute_session"
,
"//syntaxnet:base"
,
"//syntaxnet:test_main"
,
"@org_tensorflow//tensorflow/core:lib"
,
"@org_tensorflow//tensorflow/core:test"
,
],
)
...
...
@@ -151,8 +150,8 @@ cc_test(
":index_translator"
,
"//dragnn/core/test:mock_component"
,
"//dragnn/core/test:mock_transition_state"
,
"//syntaxnet:base"
,
"//syntaxnet:test_main"
,
"@org_tensorflow//tensorflow/core:test"
,
],
)
...
...
@@ -162,8 +161,8 @@ cc_test(
deps
=
[
":input_batch_cache"
,
"//dragnn/core/interfaces:input_batch"
,
"//syntaxnet:base"
,
"//syntaxnet:test_main"
,
"@org_tensorflow//tensorflow/core:test"
,
],
)
...
...
@@ -172,8 +171,8 @@ cc_test(
srcs
=
[
"resource_container_test.cc"
],
deps
=
[
":resource_container"
,
"//syntaxnet:base"
,
"//syntaxnet:test_main"
,
"@org_tensorflow//tensorflow/core:test"
,
],
)
...
...
@@ -213,7 +212,7 @@ cc_library(
deps
=
[
":compute_session"
,
":resource_container"
,
"
@org_tensorflow//tensorflow/core:framework
"
,
"
//syntaxnet:base
"
,
"@org_tensorflow//third_party/eigen3"
,
],
)
...
...
@@ -231,8 +230,7 @@ cc_library(
":resource_container"
,
"//dragnn/protos:data_proto"
,
"//dragnn/protos:spec_proto"
,
"@org_tensorflow//tensorflow/core:framework"
,
"@org_tensorflow//tensorflow/core:lib"
,
"//syntaxnet:base"
,
"@org_tensorflow//third_party/eigen3"
,
],
alwayslink
=
1
,
...
...
@@ -247,8 +245,7 @@ cc_library(
deps
=
[
":compute_session_op"
,
":resource_container"
,
"@org_tensorflow//tensorflow/core:framework"
,
"@org_tensorflow//tensorflow/core:lib"
,
"//syntaxnet:base"
,
"@org_tensorflow//tensorflow/core:protos_all_cc"
,
"@org_tensorflow//third_party/eigen3"
,
],
...
...
@@ -271,8 +268,7 @@ tf_kernel_library(
":resource_container"
,
"//dragnn/protos:data_proto"
,
"//dragnn/protos:spec_proto"
,
"@org_tensorflow//tensorflow/core:framework"
,
"@org_tensorflow//tensorflow/core:lib"
,
"//syntaxnet:base"
,
"@org_tensorflow//third_party/eigen3"
,
],
)
...
...
@@ -292,8 +288,7 @@ tf_kernel_library(
":resource_container"
,
"//dragnn/components/util:bulk_feature_extractor"
,
"//dragnn/protos:spec_proto"
,
"@org_tensorflow//tensorflow/core:framework"
,
"@org_tensorflow//tensorflow/core:lib"
,
"//syntaxnet:base"
,
"@org_tensorflow//tensorflow/core:protos_all_cc"
,
"@org_tensorflow//third_party/eigen3"
,
],
...
...
@@ -311,11 +306,9 @@ cc_test(
":resource_container"
,
"//dragnn/core/test:generic"
,
"//dragnn/core/test:mock_compute_session"
,
"//syntaxnet:base"
,
"//syntaxnet:test_main"
,
"@org_tensorflow//tensorflow/core:framework"
,
"@org_tensorflow//tensorflow/core:protos_all_cc"
,
"@org_tensorflow//tensorflow/core:test"
,
"@org_tensorflow//tensorflow/core:testlib"
,
"@org_tensorflow//tensorflow/core/kernels:ops_testutil"
,
"@org_tensorflow//tensorflow/core/kernels:ops_util"
,
"@org_tensorflow//tensorflow/core/kernels:quantized_ops"
,
...
...
@@ -331,9 +324,8 @@ cc_test(
":resource_container"
,
"//dragnn/components/util:bulk_feature_extractor"
,
"//dragnn/core/test:mock_compute_session"
,
"//syntaxnet:base"
,
"//syntaxnet:test_main"
,
"@org_tensorflow//tensorflow/core:framework"
,
"@org_tensorflow//tensorflow/core:testlib"
,
"@org_tensorflow//tensorflow/core/kernels:ops_testutil"
,
"@org_tensorflow//tensorflow/core/kernels:quantized_ops"
,
],
...
...
syntaxnet/dragnn/core/beam.h
View file @
277f99c7
// Copyright 2017 Google Inc. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
#ifndef NLP_SAFT_OPENSOURCE_DRAGNN_CORE_BEAM_H_
#define NLP_SAFT_OPENSOURCE_DRAGNN_CORE_BEAM_H_
...
...
@@ -126,7 +141,7 @@ class Beam {
const
auto
comparator
=
[](
const
Transition
&
a
,
const
Transition
&
b
)
{
return
a
.
resulting_score
>
b
.
resulting_score
;
};
std
::
sort
(
candidates
.
begin
(),
candidates
.
end
(),
comparator
);
std
::
stable_
sort
(
candidates
.
begin
(),
candidates
.
end
(),
comparator
);
// Apply the top transitions, up to a maximum of 'max_size_'.
std
::
vector
<
std
::
unique_ptr
<
T
>>
new_beam
;
...
...
syntaxnet/dragnn/core/beam_test.cc
View file @
277f99c7
// Copyright 2017 Google Inc. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
#include "dragnn/core/beam.h"
#include "dragnn/core/interfaces/cloneable_transition_state.h"
...
...
syntaxnet/dragnn/core/component_registry.cc
View file @
277f99c7
// Copyright 2017 Google Inc. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
#include "dragnn/core/component_registry.h"
namespace
syntaxnet
{
...
...
syntaxnet/dragnn/core/component_registry.h
View file @
277f99c7
// Copyright 2017 Google Inc. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
#ifndef NLP_SAFT_OPENSOURCE_DRAGNN_CORE_COMPONENT_REGISTRY_H_
#define NLP_SAFT_OPENSOURCE_DRAGNN_CORE_COMPONENT_REGISTRY_H_
...
...
Prev
1
2
3
4
5
6
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment