Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
ModelZoo
ResNet50_tensorflow
Commits
ea3fa4a3
"llama/llama.cpp/src/unicode.cpp" did not exist on "f2890a4494f9fb3722ee7a4c506252362d1eab65"
Commit
ea3fa4a3
authored
Mar 22, 2017
by
Ivan Bogatyy
Browse files
Update DRAGNN, fix some macOS issues
parent
b7523ee5
Changes
115
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
674 additions
and
49 deletions
+674
-49
syntaxnet/dragnn/components/stateless/BUILD
syntaxnet/dragnn/components/stateless/BUILD
+34
-0
syntaxnet/dragnn/components/stateless/stateless_component.cc
syntaxnet/dragnn/components/stateless/stateless_component.cc
+131
-0
syntaxnet/dragnn/components/stateless/stateless_component_test.cc
...t/dragnn/components/stateless/stateless_component_test.cc
+171
-0
syntaxnet/dragnn/components/syntaxnet/BUILD
syntaxnet/dragnn/components/syntaxnet/BUILD
+8
-9
syntaxnet/dragnn/components/syntaxnet/syntaxnet_component.cc
syntaxnet/dragnn/components/syntaxnet/syntaxnet_component.cc
+15
-0
syntaxnet/dragnn/components/syntaxnet/syntaxnet_component.h
syntaxnet/dragnn/components/syntaxnet/syntaxnet_component.h
+15
-0
syntaxnet/dragnn/components/syntaxnet/syntaxnet_component_test.cc
...t/dragnn/components/syntaxnet/syntaxnet_component_test.cc
+104
-5
syntaxnet/dragnn/components/syntaxnet/syntaxnet_link_feature_extractor.cc
.../components/syntaxnet/syntaxnet_link_feature_extractor.cc
+15
-0
syntaxnet/dragnn/components/syntaxnet/syntaxnet_link_feature_extractor.h
...n/components/syntaxnet/syntaxnet_link_feature_extractor.h
+15
-0
syntaxnet/dragnn/components/syntaxnet/syntaxnet_link_feature_extractor_test.cc
...onents/syntaxnet/syntaxnet_link_feature_extractor_test.cc
+15
-0
syntaxnet/dragnn/components/syntaxnet/syntaxnet_transition_state.cc
...dragnn/components/syntaxnet/syntaxnet_transition_state.cc
+15
-0
syntaxnet/dragnn/components/syntaxnet/syntaxnet_transition_state.h
.../dragnn/components/syntaxnet/syntaxnet_transition_state.h
+15
-0
syntaxnet/dragnn/components/syntaxnet/syntaxnet_transition_state_test.cc
...n/components/syntaxnet/syntaxnet_transition_state_test.cc
+15
-0
syntaxnet/dragnn/components/util/BUILD
syntaxnet/dragnn/components/util/BUILD
+5
-2
syntaxnet/dragnn/components/util/bulk_feature_extractor.h
syntaxnet/dragnn/components/util/bulk_feature_extractor.h
+15
-0
syntaxnet/dragnn/core/BUILD
syntaxnet/dragnn/core/BUILD
+23
-31
syntaxnet/dragnn/core/beam.h
syntaxnet/dragnn/core/beam.h
+18
-2
syntaxnet/dragnn/core/beam_test.cc
syntaxnet/dragnn/core/beam_test.cc
+15
-0
syntaxnet/dragnn/core/component_registry.cc
syntaxnet/dragnn/core/component_registry.cc
+15
-0
syntaxnet/dragnn/core/component_registry.h
syntaxnet/dragnn/core/component_registry.h
+15
-0
No files found.
syntaxnet/dragnn/components/stateless/BUILD
0 → 100644
View file @
ea3fa4a3
package
(
default_visibility
=
[
"//visibility:public"
],
features
=
[
"-layering_check"
],
)
cc_library
(
name
=
"stateless_component"
,
srcs
=
[
"stateless_component.cc"
],
deps
=
[
"//dragnn/core:component_registry"
,
"//dragnn/core/interfaces:component"
,
"//dragnn/core/interfaces:transition_state"
,
"//dragnn/io:sentence_input_batch"
,
"//dragnn/protos:data_proto"
,
"//syntaxnet:base"
,
],
alwayslink
=
1
,
)
cc_test
(
name
=
"stateless_component_test"
,
srcs
=
[
"stateless_component_test.cc"
],
deps
=
[
":stateless_component"
,
"//dragnn/core:component_registry"
,
"//dragnn/core:input_batch_cache"
,
"//dragnn/core/test:generic"
,
"//dragnn/core/test:mock_transition_state"
,
"//dragnn/io:sentence_input_batch"
,
"//syntaxnet:base"
,
"//syntaxnet:sentence_proto"
,
"//syntaxnet:test_main"
,
],
)
syntaxnet/dragnn/components/stateless/stateless_component.cc
0 → 100644
View file @
ea3fa4a3
// Copyright 2017 Google Inc. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
#include "dragnn/core/component_registry.h"
#include "dragnn/core/interfaces/component.h"
#include "dragnn/core/interfaces/transition_state.h"
#include "dragnn/io/sentence_input_batch.h"
#include "dragnn/protos/data.pb.h"
#include "syntaxnet/base.h"
namespace
syntaxnet
{
namespace
dragnn
{
namespace
{
// A component that does not create its own transition states; instead, it
// simply forwards the states of the previous component. Does not support all
// methods. Intended for "compute-only" bulk components that only use linked
// features, which use only a small subset of DRAGNN functionality.
class
StatelessComponent
:
public
Component
{
public:
void
InitializeComponent
(
const
ComponentSpec
&
spec
)
override
{
name_
=
spec
.
name
();
}
// Stores the |parent_states| for forwarding to downstream components.
void
InitializeData
(
const
std
::
vector
<
std
::
vector
<
const
TransitionState
*>>
&
parent_states
,
int
max_beam_size
,
InputBatchCache
*
input_data
)
override
{
// Must use SentenceInputBatch to match SyntaxNetComponent.
batch_size_
=
input_data
->
GetAs
<
SentenceInputBatch
>
()
->
data
()
->
size
();
beam_size_
=
max_beam_size
;
parent_states_
=
parent_states
;
// The beam should be wide enough for the previous component.
for
(
const
auto
&
beam
:
parent_states
)
{
CHECK_LE
(
beam
.
size
(),
beam_size_
);
}
}
// Forwards the states of the previous component.
std
::
vector
<
std
::
vector
<
const
TransitionState
*>>
GetBeam
()
override
{
return
parent_states_
;
}
// Forwards the |current_index| to the previous component.
int
GetSourceBeamIndex
(
int
current_index
,
int
batch
)
const
override
{
return
current_index
;
}
string
Name
()
const
override
{
return
name_
;
}
int
BeamSize
()
const
override
{
return
beam_size_
;
}
int
BatchSize
()
const
override
{
return
batch_size_
;
}
int
StepsTaken
(
int
batch_index
)
const
override
{
return
0
;
}
bool
IsReady
()
const
override
{
return
true
;
}
bool
IsTerminal
()
const
override
{
return
true
;
}
void
FinalizeData
()
override
{}
void
ResetComponent
()
override
{}
void
InitializeTracing
()
override
{}
void
DisableTracing
()
override
{}
std
::
vector
<
std
::
vector
<
ComponentTrace
>>
GetTraceProtos
()
const
override
{
return
{};
}
// Unsupported methods.
int
GetBeamIndexAtStep
(
int
step
,
int
current_index
,
int
batch
)
const
override
{
LOG
(
FATAL
)
<<
"["
<<
name_
<<
"] Method not supported"
;
return
0
;
}
std
::
function
<
int
(
int
,
int
,
int
)
>
GetStepLookupFunction
(
const
string
&
method
)
override
{
LOG
(
FATAL
)
<<
"["
<<
name_
<<
"] Method not supported"
;
return
nullptr
;
}
void
AdvanceFromPrediction
(
const
float
transition_matrix
[],
int
matrix_length
)
override
{
LOG
(
FATAL
)
<<
"["
<<
name_
<<
"] Method not supported"
;
}
void
AdvanceFromOracle
()
override
{
LOG
(
FATAL
)
<<
"["
<<
name_
<<
"] Method not supported"
;
}
std
::
vector
<
std
::
vector
<
int
>>
GetOracleLabels
()
const
override
{
LOG
(
FATAL
)
<<
"["
<<
name_
<<
"] Method not supported"
;
return
{};
}
int
GetFixedFeatures
(
std
::
function
<
int32
*
(
int
)
>
allocate_indices
,
std
::
function
<
int64
*
(
int
)
>
allocate_ids
,
std
::
function
<
float
*
(
int
)
>
allocate_weights
,
int
channel_id
)
const
override
{
LOG
(
FATAL
)
<<
"["
<<
name_
<<
"] Method not supported"
;
return
0
;
}
int
BulkGetFixedFeatures
(
const
BulkFeatureExtractor
&
extractor
)
override
{
LOG
(
FATAL
)
<<
"["
<<
name_
<<
"] Method not supported"
;
return
0
;
}
std
::
vector
<
LinkFeatures
>
GetRawLinkFeatures
(
int
channel_id
)
const
override
{
LOG
(
FATAL
)
<<
"["
<<
name_
<<
"] Method not supported"
;
return
{};
}
void
AddTranslatedLinkFeaturesToTrace
(
const
std
::
vector
<
LinkFeatures
>
&
features
,
int
channel_id
)
override
{
LOG
(
FATAL
)
<<
"["
<<
name_
<<
"] Method not supported"
;
}
private:
string
name_
;
// component name
int
batch_size_
=
1
;
// number of sentences in current batch
int
beam_size_
=
1
;
// maximum beam size
// Parent states passed to InitializeData(), and passed along in GetBeam().
std
::
vector
<
std
::
vector
<
const
TransitionState
*>>
parent_states_
;
};
REGISTER_DRAGNN_COMPONENT
(
StatelessComponent
);
}
// namespace
}
// namespace dragnn
}
// namespace syntaxnet
syntaxnet/dragnn/components/stateless/stateless_component_test.cc
0 → 100644
View file @
ea3fa4a3
// Copyright 2017 Google Inc. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
#include "dragnn/core/component_registry.h"
#include "dragnn/core/input_batch_cache.h"
#include "dragnn/core/test/generic.h"
#include "dragnn/core/test/mock_transition_state.h"
#include "dragnn/io/sentence_input_batch.h"
#include "syntaxnet/base.h"
#include "syntaxnet/sentence.pb.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/io/path.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/protobuf.h"
#include "tensorflow/core/platform/test.h"
namespace
syntaxnet
{
namespace
dragnn
{
namespace
{
const
char
kSentence0
[]
=
R"(
token {
word: "Sentence" start: 0 end: 7 tag: "NN" category: "NOUN" label: "ROOT"
break_level: NO_BREAK
}
token {
word: "0" start: 9 end: 9 head: 0 tag: "CD" category: "NUM" label: "num"
break_level: SPACE_BREAK
}
token {
word: "." start: 10 end: 10 head: 0 tag: "." category: "." label: "punct"
break_level: NO_BREAK
}
)"
;
const
char
kSentence1
[]
=
R"(
token {
word: "Sentence" start: 0 end: 7 tag: "NN" category: "NOUN" label: "ROOT"
break_level: NO_BREAK
}
token {
word: "1" start: 9 end: 9 head: 0 tag: "CD" category: "NUM" label: "num"
break_level: SPACE_BREAK
}
token {
word: "." start: 10 end: 10 head: 0 tag: "." category: "." label: "punct"
break_level: NO_BREAK
}
)"
;
const
char
kLongSentence
[]
=
R"(
token {
word: "Sentence" start: 0 end: 7 tag: "NN" category: "NOUN" label: "ROOT"
break_level: NO_BREAK
}
token {
word: "1" start: 9 end: 9 head: 0 tag: "CD" category: "NUM" label: "num"
break_level: SPACE_BREAK
}
token {
word: "2" start: 10 end: 10 head: 0 tag: "CD" category: "NUM" label: "num"
break_level: SPACE_BREAK
}
token {
word: "3" start: 11 end: 11 head: 0 tag: "CD" category: "NUM" label: "num"
break_level: SPACE_BREAK
}
token {
word: "." start: 12 end: 12 head: 0 tag: "." category: "." label: "punct"
break_level: NO_BREAK
}
)"
;
const
char
kMasterSpec
[]
=
R"(
component {
name: "test"
transition_system {
registered_name: "shift-only"
}
linked_feature {
name: "prev"
fml: "input.focus"
embedding_dim: 32
size: 1
source_component: "prev"
source_translator: "identity"
source_layer: "last_layer"
}
backend {
registered_name: "StatelessComponent"
}
}
)"
;
}
// namespace
using
testing
::
Return
;
class
StatelessComponentTest
:
public
::
testing
::
Test
{
public:
std
::
unique_ptr
<
Component
>
CreateParser
(
int
beam_size
,
const
std
::
vector
<
std
::
vector
<
const
TransitionState
*>>
&
states
,
const
std
::
vector
<
string
>
&
data
)
{
MasterSpec
master_spec
;
CHECK
(
TextFormat
::
ParseFromString
(
kMasterSpec
,
&
master_spec
));
data_
.
reset
(
new
InputBatchCache
(
data
));
// Create a parser component with the specified beam size.
std
::
unique_ptr
<
Component
>
parser_component
(
Component
::
Create
(
"StatelessComponent"
));
parser_component
->
InitializeComponent
(
master_spec
.
component
(
0
));
parser_component
->
InitializeData
(
states
,
beam_size
,
data_
.
get
());
return
parser_component
;
}
std
::
unique_ptr
<
InputBatchCache
>
data_
;
};
TEST_F
(
StatelessComponentTest
,
ForwardsTransitionStates
)
{
const
MockTransitionState
mock_state_1
,
mock_state_2
,
mock_state_3
;
const
std
::
vector
<
std
::
vector
<
const
TransitionState
*>>
parent_states
=
{
{},
{
&
mock_state_1
},
{
&
mock_state_2
,
&
mock_state_3
}};
std
::
vector
<
string
>
data
;
for
(
const
string
&
textproto
:
{
kSentence0
,
kSentence1
,
kLongSentence
})
{
Sentence
sentence
;
CHECK
(
TextFormat
::
ParseFromString
(
textproto
,
&
sentence
));
data
.
emplace_back
();
CHECK
(
sentence
.
SerializeToString
(
&
data
.
back
()));
}
CHECK_EQ
(
parent_states
.
size
(),
data
.
size
());
const
int
kBeamSize
=
2
;
auto
test_parser
=
CreateParser
(
kBeamSize
,
parent_states
,
data
);
EXPECT_TRUE
(
test_parser
->
IsReady
());
EXPECT_TRUE
(
test_parser
->
IsTerminal
());
EXPECT_EQ
(
kBeamSize
,
test_parser
->
BeamSize
());
EXPECT_EQ
(
data
.
size
(),
test_parser
->
BatchSize
());
EXPECT_TRUE
(
test_parser
->
GetTraceProtos
().
empty
());
for
(
int
batch_index
=
0
;
batch_index
<
parent_states
.
size
();
++
batch_index
)
{
EXPECT_EQ
(
0
,
test_parser
->
StepsTaken
(
batch_index
));
const
auto
&
beam
=
parent_states
[
batch_index
];
for
(
int
beam_index
=
0
;
beam_index
<
beam
.
size
();
++
beam_index
)
{
// Expect an identity mapping.
EXPECT_EQ
(
beam_index
,
test_parser
->
GetSourceBeamIndex
(
beam_index
,
batch_index
));
}
}
const
auto
forwarded_states
=
test_parser
->
GetBeam
();
EXPECT_EQ
(
parent_states
,
forwarded_states
);
}
}
// namespace dragnn
}
// namespace syntaxnet
syntaxnet/dragnn/components/syntaxnet/BUILD
View file @
ea3fa4a3
package
(
default_visibility
=
[
"//visibility:public"
])
package
(
default_visibility
=
[
"//visibility:public"
],
features
=
[
"-layering_check"
],
)
cc_library
(
name
=
"syntaxnet_component"
,
...
...
@@ -25,7 +28,6 @@ cc_library(
"//syntaxnet:task_context"
,
"//syntaxnet:task_spec_proto"
,
"//syntaxnet:utils"
,
"@org_tensorflow//tensorflow/core:lib"
,
# For tf/core/platform/logging.h
],
alwayslink
=
1
,
)
...
...
@@ -36,10 +38,10 @@ cc_library(
hdrs
=
[
"syntaxnet_link_feature_extractor.h"
],
deps
=
[
"//dragnn/protos:spec_proto"
,
"//syntaxnet:base"
,
"//syntaxnet:embedding_feature_extractor"
,
"//syntaxnet:parser_transitions"
,
"//syntaxnet:task_context"
,
"@org_tensorflow//tensorflow/core:lib"
,
# For tf/core/platform/logging.h
],
)
...
...
@@ -54,7 +56,6 @@ cc_library(
"//dragnn/protos:trace_proto"
,
"//syntaxnet:base"
,
"//syntaxnet:parser_transitions"
,
"@org_tensorflow//tensorflow/core:lib"
,
# For tf/core/platform/logging.h
],
)
...
...
@@ -75,9 +76,9 @@ cc_test(
"//dragnn/core/test:generic"
,
"//dragnn/core/test:mock_transition_state"
,
"//dragnn/io:sentence_input_batch"
,
"//syntaxnet:base"
,
"//syntaxnet:sentence_proto"
,
"@org_tensorflow//tensorflow/core:lib"
,
"@org_tensorflow//tensorflow/core:test"
,
"//syntaxnet:test_main"
,
],
)
...
...
@@ -90,7 +91,6 @@ cc_test(
"//dragnn/protos:spec_proto"
,
"//syntaxnet:task_context"
,
"//syntaxnet:test_main"
,
"@org_tensorflow//tensorflow/core:test"
,
"@org_tensorflow//tensorflow/core:testlib"
,
],
)
...
...
@@ -107,10 +107,9 @@ cc_test(
"//dragnn/core/test:mock_transition_state"
,
"//dragnn/io:sentence_input_batch"
,
"//dragnn/protos:spec_proto"
,
"//syntaxnet:base"
,
"//syntaxnet:sentence_proto"
,
"//syntaxnet:test_main"
,
"@org_tensorflow//tensorflow/core:lib"
,
"@org_tensorflow//tensorflow/core:test"
,
"@org_tensorflow//tensorflow/core:testlib"
,
],
)
syntaxnet/dragnn/components/syntaxnet/syntaxnet_component.cc
View file @
ea3fa4a3
// Copyright 2017 Google Inc. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
#include "dragnn/components/syntaxnet/syntaxnet_component.h"
#include <vector>
...
...
syntaxnet/dragnn/components/syntaxnet/syntaxnet_component.h
View file @
ea3fa4a3
// Copyright 2017 Google Inc. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
#ifndef NLP_SAFT_OPENSOURCE_DRAGNN_COMPONENTS_SYNTAXNET_SYNTAXNET_COMPONENT_H_
#define NLP_SAFT_OPENSOURCE_DRAGNN_COMPONENTS_SYNTAXNET_SYNTAXNET_COMPONENT_H_
...
...
syntaxnet/dragnn/components/syntaxnet/syntaxnet_component_test.cc
View file @
ea3fa4a3
// Copyright 2017 Google Inc. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
#include "dragnn/components/syntaxnet/syntaxnet_component.h"
#include "dragnn/core/input_batch_cache.h"
...
...
@@ -807,6 +822,90 @@ TEST_F(SyntaxNetComponentTest, ExportsFixedFeatures) {
transition_matrix
[
i
]
=
kTransitionValue
;
}
// Advance twice, so that the underlying parser fills the beam.
test_parser
->
AdvanceFromPrediction
(
transition_matrix
,
kNumPossibleTransitions
*
kBeamSize
*
kBatchSize
);
test_parser
->
AdvanceFromPrediction
(
transition_matrix
,
kNumPossibleTransitions
*
kBeamSize
*
kBatchSize
);
// Get and check the raw link features.
vector
<
int32
>
indices
;
auto
indices_fn
=
[
&
indices
](
int
size
)
{
indices
.
resize
(
size
);
return
indices
.
data
();
};
vector
<
int64
>
ids
;
auto
ids_fn
=
[
&
ids
](
int
size
)
{
ids
.
resize
(
size
);
return
ids
.
data
();
};
vector
<
float
>
weights
;
auto
weights_fn
=
[
&
weights
](
int
size
)
{
weights
.
resize
(
size
);
return
weights
.
data
();
};
constexpr
int
kChannelId
=
0
;
const
int
num_features
=
test_parser
->
GetFixedFeatures
(
indices_fn
,
ids_fn
,
weights_fn
,
kChannelId
);
constexpr
int
kExpectedOutputSize
=
12
;
const
vector
<
int32
>
expected_indices
({
0
,
1
,
2
,
3
,
4
,
5
,
6
,
7
,
8
,
9
,
10
,
11
});
const
vector
<
int64
>
expected_ids
({
7
,
50
,
12
,
7
,
12
,
7
,
7
,
50
,
12
,
7
,
12
,
7
});
const
vector
<
float
>
expected_weights
(
{
1.0
,
1.0
,
1.0
,
1.0
,
1.0
,
1.0
,
1.0
,
1.0
,
1.0
,
1.0
,
1.0
,
1.0
});
EXPECT_EQ
(
expected_indices
.
size
(),
kExpectedOutputSize
);
EXPECT_EQ
(
expected_ids
.
size
(),
kExpectedOutputSize
);
EXPECT_EQ
(
expected_weights
.
size
(),
kExpectedOutputSize
);
EXPECT_EQ
(
num_features
,
kExpectedOutputSize
);
EXPECT_EQ
(
expected_indices
,
indices
);
EXPECT_EQ
(
expected_ids
,
ids
);
EXPECT_EQ
(
expected_weights
,
weights
);
}
TEST_F
(
SyntaxNetComponentTest
,
AdvancesAccordingToHighestWeightedInputOption
)
{
// Create an empty input batch and beam vector to initialize the parser.
Sentence
sentence_0
;
TextFormat
::
ParseFromString
(
kSentence0
,
&
sentence_0
);
string
sentence_0_str
;
sentence_0
.
SerializeToString
(
&
sentence_0_str
);
Sentence
sentence_1
;
TextFormat
::
ParseFromString
(
kSentence1
,
&
sentence_1
);
string
sentence_1_str
;
sentence_1
.
SerializeToString
(
&
sentence_1_str
);
constexpr
int
kBeamSize
=
3
;
auto
test_parser
=
CreateParserWithBeamSize
(
kBeamSize
,
{},
{
sentence_0_str
,
sentence_1_str
});
// There are 93 possible transitions for any given state. Create a transition
// array with a score of 10.0 for each transition.
constexpr
int
kBatchSize
=
2
;
constexpr
int
kNumPossibleTransitions
=
93
;
constexpr
float
kTransitionValue
=
10.0
;
float
transition_matrix
[
kNumPossibleTransitions
*
kBeamSize
*
kBatchSize
];
for
(
int
i
=
0
;
i
<
kNumPossibleTransitions
*
kBeamSize
*
kBatchSize
;
++
i
)
{
transition_matrix
[
i
]
=
kTransitionValue
;
}
// Replace the first several options with varying scores to test sorting.
constexpr
int
kBatchOffset
=
kNumPossibleTransitions
*
kBeamSize
;
transition_matrix
[
0
]
=
3
*
kTransitionValue
;
transition_matrix
[
1
]
=
3
*
kTransitionValue
;
transition_matrix
[
2
]
=
4
*
kTransitionValue
;
transition_matrix
[
3
]
=
4
*
kTransitionValue
;
transition_matrix
[
4
]
=
2
*
kTransitionValue
;
transition_matrix
[
5
]
=
2
*
kTransitionValue
;
transition_matrix
[
kBatchOffset
+
0
]
=
3
*
kTransitionValue
;
transition_matrix
[
kBatchOffset
+
1
]
=
3
*
kTransitionValue
;
transition_matrix
[
kBatchOffset
+
2
]
=
4
*
kTransitionValue
;
transition_matrix
[
kBatchOffset
+
3
]
=
4
*
kTransitionValue
;
transition_matrix
[
kBatchOffset
+
4
]
=
2
*
kTransitionValue
;
transition_matrix
[
kBatchOffset
+
5
]
=
2
*
kTransitionValue
;
// Advance twice, so that the underlying parser fills the beam.
test_parser
->
AdvanceFromPrediction
(
transition_matrix
,
kNumPossibleTransitions
*
kBeamSize
*
kBatchSize
);
...
...
@@ -836,7 +935,7 @@ TEST_F(SyntaxNetComponentTest, ExportsFixedFeatures) {
// In this case, all even features and all odd features are identical.
constexpr
int
kExpectedOutputSize
=
12
;
const
vector
<
int32
>
expected_indices
({
0
,
1
,
2
,
3
,
4
,
5
,
6
,
7
,
8
,
9
,
10
,
11
});
const
vector
<
int64
>
expected_ids
({
12
,
7
,
12
,
7
,
12
,
7
,
12
,
7
,
12
,
7
,
12
,
7
});
const
vector
<
int64
>
expected_ids
({
12
,
7
,
7
,
50
,
12
,
7
,
12
,
7
,
7
,
50
,
12
,
7
});
const
vector
<
float
>
expected_weights
(
{
1.0
,
1.0
,
1.0
,
1.0
,
1.0
,
1.0
,
1.0
,
1.0
,
1.0
,
1.0
,
1.0
,
1.0
});
...
...
@@ -1024,11 +1123,11 @@ TEST_F(SyntaxNetComponentTest, ExportsRawLinkFeatures) {
EXPECT_EQ
(
link_features
.
size
(),
kBeamSize
*
kBatchSize
*
kNumLinkFeatures
);
// These should index into batch 0.
EXPECT_EQ
(
link_features
.
at
(
0
).
feature_value
(),
-
1
);
EXPECT_EQ
(
link_features
.
at
(
0
).
feature_value
(),
1
);
EXPECT_EQ
(
link_features
.
at
(
0
).
batch_idx
(),
0
);
EXPECT_EQ
(
link_features
.
at
(
0
).
beam_idx
(),
0
);
EXPECT_EQ
(
link_features
.
at
(
1
).
feature_value
(),
-
2
);
EXPECT_EQ
(
link_features
.
at
(
1
).
feature_value
(),
0
);
EXPECT_EQ
(
link_features
.
at
(
1
).
batch_idx
(),
0
);
EXPECT_EQ
(
link_features
.
at
(
1
).
beam_idx
(),
0
);
...
...
@@ -1049,11 +1148,11 @@ TEST_F(SyntaxNetComponentTest, ExportsRawLinkFeatures) {
EXPECT_EQ
(
link_features
.
at
(
5
).
beam_idx
(),
2
);
// These should index into batch 1.
EXPECT_EQ
(
link_features
.
at
(
6
).
feature_value
(),
-
1
);
EXPECT_EQ
(
link_features
.
at
(
6
).
feature_value
(),
1
);
EXPECT_EQ
(
link_features
.
at
(
6
).
batch_idx
(),
1
);
EXPECT_EQ
(
link_features
.
at
(
6
).
beam_idx
(),
0
);
EXPECT_EQ
(
link_features
.
at
(
7
).
feature_value
(),
-
2
);
EXPECT_EQ
(
link_features
.
at
(
7
).
feature_value
(),
0
);
EXPECT_EQ
(
link_features
.
at
(
7
).
batch_idx
(),
1
);
EXPECT_EQ
(
link_features
.
at
(
7
).
beam_idx
(),
0
);
...
...
syntaxnet/dragnn/components/syntaxnet/syntaxnet_link_feature_extractor.cc
View file @
ea3fa4a3
// Copyright 2017 Google Inc. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
#include "dragnn/components/syntaxnet/syntaxnet_link_feature_extractor.h"
#include "tensorflow/core/platform/logging.h"
...
...
syntaxnet/dragnn/components/syntaxnet/syntaxnet_link_feature_extractor.h
View file @
ea3fa4a3
// Copyright 2017 Google Inc. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
#ifndef NLP_SAFT_OPENSOURCE_DRAGNN_COMPONENTS_SYNTAXNET_SYNTAXNET_LINK_FEATURE_EXTRACTOR_H_
#define NLP_SAFT_OPENSOURCE_DRAGNN_COMPONENTS_SYNTAXNET_SYNTAXNET_LINK_FEATURE_EXTRACTOR_H_
...
...
syntaxnet/dragnn/components/syntaxnet/syntaxnet_link_feature_extractor_test.cc
View file @
ea3fa4a3
// Copyright 2017 Google Inc. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
#include "dragnn/components/syntaxnet/syntaxnet_link_feature_extractor.h"
#include <string>
...
...
syntaxnet/dragnn/components/syntaxnet/syntaxnet_transition_state.cc
View file @
ea3fa4a3
// Copyright 2017 Google Inc. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
#include "dragnn/components/syntaxnet/syntaxnet_transition_state.h"
#include "tensorflow/core/lib/strings/strcat.h"
...
...
syntaxnet/dragnn/components/syntaxnet/syntaxnet_transition_state.h
View file @
ea3fa4a3
// Copyright 2017 Google Inc. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
#ifndef NLP_SAFT_OPENSOURCE_DRAGNN_COMPONENTS_SYNTAXNET_SYNTAXNET_TRANSITION_STATE_H_
#define NLP_SAFT_OPENSOURCE_DRAGNN_COMPONENTS_SYNTAXNET_SYNTAXNET_TRANSITION_STATE_H_
...
...
syntaxnet/dragnn/components/syntaxnet/syntaxnet_transition_state_test.cc
View file @
ea3fa4a3
// Copyright 2017 Google Inc. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
#include "dragnn/components/syntaxnet/syntaxnet_transition_state.h"
#include "dragnn/components/syntaxnet/syntaxnet_component.h"
...
...
syntaxnet/dragnn/components/util/BUILD
View file @
ea3fa4a3
package
(
default_visibility
=
[
"//visibility:public"
])
package
(
default_visibility
=
[
"//visibility:public"
],
features
=
[
"-layering_check"
],
)
cc_library
(
name
=
"bulk_feature_extractor"
,
hdrs
=
[
"bulk_feature_extractor.h"
],
deps
=
[
"
@org_tensorflow//tensorflow/core:lib
"
,
"
//syntaxnet:base
"
,
],
)
syntaxnet/dragnn/components/util/bulk_feature_extractor.h
View file @
ea3fa4a3
// Copyright 2017 Google Inc. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
#ifndef NLP_SAFT_OPENSOURCE_DRAGNN_COMPONENTS_UTIL_BULK_FEATURE_EXTRACTOR_H_
#define NLP_SAFT_OPENSOURCE_DRAGNN_COMPONENTS_UTIL_BULK_FEATURE_EXTRACTOR_H_
...
...
syntaxnet/dragnn/core/BUILD
View file @
ea3fa4a3
package
(
default_visibility
=
[
"//visibility:public"
])
package
(
default_visibility
=
[
"//visibility:public"
],
features
=
[
"-layering_check"
],
)
# Test data.
filegroup
(
...
...
@@ -12,7 +15,7 @@ cc_library(
deps
=
[
"//dragnn/core/interfaces:cloneable_transition_state"
,
"//dragnn/core/interfaces:transition_state"
,
"
@org_tensorflow//tensorflow/core:lib"
,
# For tf/core/platform/logging.h
"
//syntaxnet:base"
,
],
)
...
...
@@ -50,8 +53,8 @@ cc_library(
"//dragnn/protos:data_proto"
,
"//dragnn/protos:spec_proto"
,
"//dragnn/protos:trace_proto"
,
"//syntaxnet:base"
,
"//syntaxnet:registry"
,
"@org_tensorflow//tensorflow/core:lib"
,
# For tf/core/platform/logging.h
],
)
...
...
@@ -64,7 +67,7 @@ cc_library(
":compute_session"
,
":compute_session_impl"
,
"//dragnn/protos:spec_proto"
,
"
@org_tensorflow//tensorflow/core:lib
"
,
"
//syntaxnet:base
"
,
],
)
...
...
@@ -75,7 +78,7 @@ cc_library(
deps
=
[
"//dragnn/core/interfaces:component"
,
"//dragnn/core/interfaces:transition_state"
,
"
@org_tensorflow//tensorflow/core:lib"
,
# For tf/core/platform/logging.h
"
//syntaxnet:base"
,
],
)
...
...
@@ -84,17 +87,14 @@ cc_library(
hdrs
=
[
"input_batch_cache.h"
],
deps
=
[
"//dragnn/core/interfaces:input_batch"
,
"
@org_tensorflow//tensorflow/core:lib"
,
# For tf/core/platform/logging.h
"
//syntaxnet:base"
,
],
)
cc_library
(
name
=
"resource_container"
,
hdrs
=
[
"resource_container.h"
],
deps
=
[
"//syntaxnet:base"
,
"@org_tensorflow//tensorflow/core:framework"
,
],
deps
=
[
"//syntaxnet:base"
],
)
# Tests
...
...
@@ -107,8 +107,8 @@ cc_test(
"//dragnn/core/interfaces:cloneable_transition_state"
,
"//dragnn/core/interfaces:transition_state"
,
"//dragnn/core/test:mock_transition_state"
,
"//syntaxnet:base"
,
"//syntaxnet:test_main"
,
"@org_tensorflow//tensorflow/core:test"
,
],
)
...
...
@@ -125,7 +125,7 @@ cc_test(
"//dragnn/core/test:generic"
,
"//dragnn/core/test:mock_component"
,
"//dragnn/core/test:mock_transition_state"
,
"
@org_tensorflow//tensorflow/core:test
"
,
"
//syntaxnet:base
"
,
],
)
...
...
@@ -138,9 +138,8 @@ cc_test(
"//dragnn/core/test:generic"
,
"//dragnn/core/test:mock_component"
,
"//dragnn/core/test:mock_compute_session"
,
"//syntaxnet:base"
,
"//syntaxnet:test_main"
,
"@org_tensorflow//tensorflow/core:lib"
,
"@org_tensorflow//tensorflow/core:test"
,
],
)
...
...
@@ -151,8 +150,8 @@ cc_test(
":index_translator"
,
"//dragnn/core/test:mock_component"
,
"//dragnn/core/test:mock_transition_state"
,
"//syntaxnet:base"
,
"//syntaxnet:test_main"
,
"@org_tensorflow//tensorflow/core:test"
,
],
)
...
...
@@ -162,8 +161,8 @@ cc_test(
deps
=
[
":input_batch_cache"
,
"//dragnn/core/interfaces:input_batch"
,
"//syntaxnet:base"
,
"//syntaxnet:test_main"
,
"@org_tensorflow//tensorflow/core:test"
,
],
)
...
...
@@ -172,8 +171,8 @@ cc_test(
srcs
=
[
"resource_container_test.cc"
],
deps
=
[
":resource_container"
,
"//syntaxnet:base"
,
"//syntaxnet:test_main"
,
"@org_tensorflow//tensorflow/core:test"
,
],
)
...
...
@@ -213,7 +212,7 @@ cc_library(
deps
=
[
":compute_session"
,
":resource_container"
,
"
@org_tensorflow//tensorflow/core:framework
"
,
"
//syntaxnet:base
"
,
"@org_tensorflow//third_party/eigen3"
,
],
)
...
...
@@ -231,8 +230,7 @@ cc_library(
":resource_container"
,
"//dragnn/protos:data_proto"
,
"//dragnn/protos:spec_proto"
,
"@org_tensorflow//tensorflow/core:framework"
,
"@org_tensorflow//tensorflow/core:lib"
,
"//syntaxnet:base"
,
"@org_tensorflow//third_party/eigen3"
,
],
alwayslink
=
1
,
...
...
@@ -247,8 +245,7 @@ cc_library(
deps
=
[
":compute_session_op"
,
":resource_container"
,
"@org_tensorflow//tensorflow/core:framework"
,
"@org_tensorflow//tensorflow/core:lib"
,
"//syntaxnet:base"
,
"@org_tensorflow//tensorflow/core:protos_all_cc"
,
"@org_tensorflow//third_party/eigen3"
,
],
...
...
@@ -271,8 +268,7 @@ tf_kernel_library(
":resource_container"
,
"//dragnn/protos:data_proto"
,
"//dragnn/protos:spec_proto"
,
"@org_tensorflow//tensorflow/core:framework"
,
"@org_tensorflow//tensorflow/core:lib"
,
"//syntaxnet:base"
,
"@org_tensorflow//third_party/eigen3"
,
],
)
...
...
@@ -292,8 +288,7 @@ tf_kernel_library(
":resource_container"
,
"//dragnn/components/util:bulk_feature_extractor"
,
"//dragnn/protos:spec_proto"
,
"@org_tensorflow//tensorflow/core:framework"
,
"@org_tensorflow//tensorflow/core:lib"
,
"//syntaxnet:base"
,
"@org_tensorflow//tensorflow/core:protos_all_cc"
,
"@org_tensorflow//third_party/eigen3"
,
],
...
...
@@ -311,11 +306,9 @@ cc_test(
":resource_container"
,
"//dragnn/core/test:generic"
,
"//dragnn/core/test:mock_compute_session"
,
"//syntaxnet:base"
,
"//syntaxnet:test_main"
,
"@org_tensorflow//tensorflow/core:framework"
,
"@org_tensorflow//tensorflow/core:protos_all_cc"
,
"@org_tensorflow//tensorflow/core:test"
,
"@org_tensorflow//tensorflow/core:testlib"
,
"@org_tensorflow//tensorflow/core/kernels:ops_testutil"
,
"@org_tensorflow//tensorflow/core/kernels:ops_util"
,
"@org_tensorflow//tensorflow/core/kernels:quantized_ops"
,
...
...
@@ -331,9 +324,8 @@ cc_test(
":resource_container"
,
"//dragnn/components/util:bulk_feature_extractor"
,
"//dragnn/core/test:mock_compute_session"
,
"//syntaxnet:base"
,
"//syntaxnet:test_main"
,
"@org_tensorflow//tensorflow/core:framework"
,
"@org_tensorflow//tensorflow/core:testlib"
,
"@org_tensorflow//tensorflow/core/kernels:ops_testutil"
,
"@org_tensorflow//tensorflow/core/kernels:quantized_ops"
,
],
...
...
syntaxnet/dragnn/core/beam.h
View file @
ea3fa4a3
// Copyright 2017 Google Inc. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
#ifndef NLP_SAFT_OPENSOURCE_DRAGNN_CORE_BEAM_H_
#define NLP_SAFT_OPENSOURCE_DRAGNN_CORE_BEAM_H_
#include <algorithm>
#include <cmath>
#include <memory>
#include <vector>
...
...
@@ -112,7 +128,7 @@ class Beam {
CHECK_LT
(
matrix_idx
,
matrix_length
)
<<
"Matrix index out of bounds!"
;
const
double
score_delta
=
transition_matrix
[
matrix_idx
];
CHECK
(
!
isnan
(
score_delta
));
CHECK
(
!
std
::
isnan
(
score_delta
));
candidate
.
source_idx
=
beam_idx
;
candidate
.
action
=
action_idx
;
candidate
.
resulting_score
=
state
->
GetScore
()
+
score_delta
;
...
...
@@ -125,7 +141,7 @@ class Beam {
const
auto
comparator
=
[](
const
Transition
&
a
,
const
Transition
&
b
)
{
return
a
.
resulting_score
>
b
.
resulting_score
;
};
std
::
sort
(
candidates
.
begin
(),
candidates
.
end
(),
comparator
);
std
::
stable_
sort
(
candidates
.
begin
(),
candidates
.
end
(),
comparator
);
// Apply the top transitions, up to a maximum of 'max_size_'.
std
::
vector
<
std
::
unique_ptr
<
T
>>
new_beam
;
...
...
syntaxnet/dragnn/core/beam_test.cc
View file @
ea3fa4a3
// Copyright 2017 Google Inc. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
#include "dragnn/core/beam.h"
#include "dragnn/core/interfaces/cloneable_transition_state.h"
...
...
syntaxnet/dragnn/core/component_registry.cc
View file @
ea3fa4a3
// Copyright 2017 Google Inc. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
#include "dragnn/core/component_registry.h"
namespace
syntaxnet
{
...
...
syntaxnet/dragnn/core/component_registry.h
View file @
ea3fa4a3
// Copyright 2017 Google Inc. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
#ifndef NLP_SAFT_OPENSOURCE_DRAGNN_CORE_COMPONENT_REGISTRY_H_
#define NLP_SAFT_OPENSOURCE_DRAGNN_CORE_COMPONENT_REGISTRY_H_
...
...
Prev
1
2
3
4
5
6
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment